Merge pull request #8930 from KonduitAI/master

Development updates
Alex Black 2020-05-11 22:57:02 +10:00 committed by GitHub
commit 9db86cec7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
56 changed files with 136 additions and 4247 deletions

View File

@@ -56,8 +56,10 @@ import static org.junit.Assert.*;
*/
public class RegressionTest050 extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(300);
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override
public DataType getDataType(){
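
All of the regression-test hunks in this commit follow the same pattern: the per-class JUnit Timeout @Rule is replaced by overriding a getTimeoutMilliseconds() hook that the shared base class turns into a single rule. A minimal sketch of how a base class can wire that up with JUnit 4's Timeout rule (the class and field names here are illustrative, not the actual BaseDL4JTest source):

import org.junit.Rule;
import org.junit.rules.Timeout;

public abstract class BaseTimeoutTest {
    // Subclasses override this; e.g. return 180000L where slow downloads are expected
    public long getTimeoutMilliseconds() {
        return 90_000L;
    }

    // A single rule on the base class replaces per-class Timeout rules;
    // the overridden value is picked up via virtual dispatch at field-init time
    @Rule
    public Timeout timeout = Timeout.millis(getTimeoutMilliseconds());
}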

View File

@@ -64,6 +64,11 @@ public class RegressionTest060 extends BaseDL4JTest {
return DataType.FLOAT;
}
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void regressionTestMLP1() throws Exception {

View File

@@ -64,6 +64,12 @@ public class RegressionTest071 extends BaseDL4JTest {
public DataType getDataType(){
return DataType.FLOAT;
}
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void regressionTestMLP1() throws Exception {

View File

@@ -64,6 +64,11 @@ public class RegressionTest080 extends BaseDL4JTest {
return DataType.FLOAT;
}
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void regressionTestMLP1() throws Exception {

View File

@@ -56,7 +56,7 @@ public class RegressionTest100a extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override

View File

@@ -52,7 +52,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override

View File

@@ -71,7 +71,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override

View File

@@ -58,7 +58,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test

View File

@@ -30,6 +30,11 @@ import static org.junit.Assert.assertTrue;
*/
public class TestDistributionDeserializer extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void testDistributionDeserializer() throws Exception {
//Test current format:

View File

@@ -46,6 +46,11 @@ public class TestDeepWalk extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Override
public long getTimeoutMilliseconds() {
return 120_000L; //Increase timeout due to intermittently slow CI machines
}
@Test(timeout = 60000L)
public void testBasic() throws IOException {
//Very basic test. Load graph, build tree, call fit, make sure it doesn't throw any exceptions

View File

@@ -84,12 +84,6 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<version>4.0.2</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@@ -27,7 +27,7 @@ import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import java.util.concurrent.atomic.AtomicLong;
/**
* Implementations of this interface should contain element-related learning algorithms. Like skip-gram, cbow or glove
* Implementations of this interface should contain element-related learning algorithms. Like skip-gram or cbow
*
* @author raver119@gmail.com
*/
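
For context, a rough sketch of the contract this javadoc describes, with the method set inferred from the GloVe implementation removed below (signatures are assumptions, not the verbatim interface; imports as in that file):

public interface ElementsLearningAlgorithm<T extends SequenceElement> {
    String getCodeName();

    void configure(VocabCache<T> vocabCache, WeightLookupTable<T> lookupTable,
                   VectorsConfiguration configuration);

    // Optional preparation pass over the corpus (e.g. co-occurrence counting)
    void pretrain(SequenceIterator<T> iterator);

    // Returns this sequence's contribution to the training error
    double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate);

    boolean isEarlyTerminationHit();

    void finish();
}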

View File

@@ -1,427 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.embeddings.learning.impl.elements;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.glove.AbstractCoOccurrences;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.legacy.AdaGrad;
import org.nd4j.common.primitives.Counter;
import org.nd4j.common.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
/**
* GloVe LearningAlgorithm implementation for SequenceVectors
*
*
* @author raver119@gmail.com
*/
public class GloVe<T extends SequenceElement> implements ElementsLearningAlgorithm<T> {
private VocabCache<T> vocabCache;
private AbstractCoOccurrences<T> coOccurrences;
private WeightLookupTable<T> lookupTable;
private VectorsConfiguration configuration;
private AtomicBoolean isTerminate = new AtomicBoolean(false);
private INDArray syn0;
private double xMax;
private boolean shuffle;
private boolean symmetric;
protected double alpha = 0.75d;
protected double learningRate = 0.0d;
protected int maxmemory = 0;
protected int batchSize = 1000;
private AdaGrad weightAdaGrad;
private AdaGrad biasAdaGrad;
private INDArray bias;
private int workers = Runtime.getRuntime().availableProcessors();
private int vectorLength;
private static final Logger log = LoggerFactory.getLogger(GloVe.class);
@Override
public String getCodeName() {
return "GloVe";
}
@Override
public void finish() {
log.info("GloVe finalizer...");
}
@Override
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable,
@NonNull VectorsConfiguration configuration) {
this.vocabCache = vocabCache;
this.lookupTable = lookupTable;
this.configuration = configuration;
this.syn0 = ((InMemoryLookupTable<T>) lookupTable).getSyn0();
this.vectorLength = configuration.getLayersSize();
if (this.learningRate == 0.0d)
this.learningRate = configuration.getLearningRate();
weightAdaGrad = new AdaGrad(new long[] {this.vocabCache.numWords() + 1, vectorLength}, learningRate);
bias = Nd4j.create(syn0.rows());
biasAdaGrad = new AdaGrad(bias.shape(), this.learningRate);
// maxmemory = Runtime.getRuntime().maxMemory() - (vocabCache.numWords() * vectorLength * 2 * 8);
log.info("GloVe params: {Max Memory: [" + maxmemory + "], Learning rate: [" + this.learningRate + "], Alpha: ["
+ alpha + "], xMax: [" + xMax + "], Symmetric: [" + symmetric + "], Shuffle: [" + shuffle
+ "]}");
}
/**
* pretrain is used to build the co-occurrence matrix for the GloVe algorithm
* @param iterator
*/
@Override
public void pretrain(@NonNull SequenceIterator<T> iterator) {
// CoOccurrence table should be built here
coOccurrences = new AbstractCoOccurrences.Builder<T>()
// TODO: symmetric should be handled via VectorsConfiguration
.symmetric(this.symmetric).windowSize(configuration.getWindow()).iterate(iterator)
.workers(workers).vocabCache(vocabCache).maxMemory(maxmemory).build();
coOccurrences.fit();
}
public double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate,
BatchSequences<T> batchSequences) {
throw new UnsupportedOperationException();
}
/**
* Learns sequence using GloVe algorithm
*
* @param sequence
* @param nextRandom
* @param learningRate
*/
@Override
public synchronized double learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom,
double learningRate) {
/*
The GloVe learning algorithm is implemented as a hack on top of the established ElementsLearningAlgorithm mechanics. It is called in the SequenceVectors context, but only the first call does any work.
All subsequent calls will meet the early termination condition and be ignored. But since the element vectors are updated within the first call,
this preserves compatibility with everything built on top of this implementation
*/
if (isTerminate.get())
return 0;
final AtomicLong pairsCount = new AtomicLong(0);
final Counter<Integer> errorCounter = new Counter<>();
//List<Pair<T, T>> coList = coOccurrences.coOccurrenceList();
for (int i = 0; i < configuration.getEpochs(); i++) {
// TODO: shuffle should be built in another way.
//if (shuffle)
//Collections.shuffle(coList);
Iterator<Pair<Pair<T, T>, Double>> pairs = coOccurrences.iterator();
List<GloveCalculationsThread> threads = new ArrayList<>();
for (int x = 0; x < workers; x++) {
threads.add(x, new GloveCalculationsThread(i, x, pairs, pairsCount, errorCounter));
threads.get(x).start();
}
for (int x = 0; x < workers; x++) {
try {
threads.get(x).join();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
log.info("Processed [" + pairsCount.get() + "] pairs, Error was [" + errorCounter.getCount(i) + "]");
}
isTerminate.set(true);
return 0;
}
/**
* Since GloVe learns representations from element co-occurrences, all training is done internally by the GloVe class, so only the first thread executes the learning process,
* and the remaining parent threads simply exit it
*
* @return True, if training should stop, False otherwise.
*/
@Override
public synchronized boolean isEarlyTerminationHit() {
return isTerminate.get();
}
private double iterateSample(T element1, T element2, double score) {
//prediction: input + bias
if (element1.getIndex() < 0 || element1.getIndex() >= syn0.rows())
throw new IllegalArgumentException("Illegal index for word " + element1.getLabel());
if (element2.getIndex() < 0 || element2.getIndex() >= syn0.rows())
throw new IllegalArgumentException("Illegal index for word " + element2.getLabel());
INDArray w1Vector = syn0.slice(element1.getIndex());
INDArray w2Vector = syn0.slice(element2.getIndex());
//w1 * w2 + bias
double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
prediction += bias.getDouble(element1.getIndex()) + bias.getDouble(element2.getIndex()) - Math.log(score);
double fDiff = (score > xMax) ? prediction : Math.pow(score / xMax, alpha) * prediction; // Math.pow(Math.min(1.0,(score / maxCount)),xMax);
// double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
if (Double.isNaN(fDiff))
fDiff = Nd4j.EPS_THRESHOLD;
//amount of change
double gradient = fDiff * learningRate;
//note the update step here: the gradient is
//the gradient of the OPPOSITE word
//for adagrad we will use the index of the word passed in
//for the gradient calculation we will use the context vector
update(element1, w1Vector, w2Vector, gradient);
update(element2, w2Vector, w1Vector, gradient);
return 0.5 * fDiff * prediction;
}
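
For reference, iterateSample() above follows the weighted least-squares objective of the Stanford GloVe paper. In the notation of Pennington et al.:

J = \sum_{i,j} f(X_{ij}) \left( w_i^\top \tilde{w}_j + b_i + \tilde{b}_j - \log X_{ij} \right)^2,
\qquad
f(x) = \begin{cases} (x / x_{\max})^{\alpha} & x < x_{\max} \\ 1 & \text{otherwise} \end{cases}

In the code, fDiff is the f-weighted prediction error, and AdaGrad scales the resulting gradient separately for each word vector and bias.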
private void update(T element1, INDArray wordVector, INDArray contextVector, double gradient) {
//gradient for word vectors
INDArray grad1 = contextVector.mul(gradient);
INDArray update = weightAdaGrad.getGradient(grad1, element1.getIndex(), syn0.shape());
//update vector
wordVector.subi(update);
double w1Bias = bias.getDouble(element1.getIndex());
double biasGradient = biasAdaGrad.getGradient(gradient, element1.getIndex(), bias.shape());
double update2 = w1Bias - biasGradient;
bias.putScalar(element1.getIndex(), update2);
}
private class GloveCalculationsThread extends Thread implements Runnable {
private final int threadId;
private final int epochId;
// private final AbstractCoOccurrences<T> coOccurrences;
private final Iterator<Pair<Pair<T, T>, Double>> coList;
private final AtomicLong pairsCounter;
private final Counter<Integer> errorCounter;
public GloveCalculationsThread(int epochId, int threadId, @NonNull Iterator<Pair<Pair<T, T>, Double>> pairs,
@NonNull AtomicLong pairsCounter, @NonNull Counter<Integer> errorCounter) {
this.epochId = epochId;
this.threadId = threadId;
// this.coOccurrences = coOccurrences;
this.pairsCounter = pairsCounter;
this.errorCounter = errorCounter;
coList = pairs;
this.setName("GloVe ELA t." + this.threadId);
}
@Override
public void run() {
// int startPosition = threadId * (coList.size() / workers);
// int stopPosition = (threadId + 1) * (coList.size() / workers);
// log.info("Total size: [" + coList.size() + "], thread start: [" + startPosition + "], thread stop: [" + stopPosition + "]");
while (coList.hasNext()) {
// now we fetch pairs into batch
List<Pair<Pair<T, T>, Double>> pairs = new ArrayList<>();
int cnt = 0;
while (coList.hasNext() && cnt < batchSize) {
pairs.add(coList.next());
cnt++;
}
if (shuffle)
Collections.shuffle(pairs);
Iterator<Pair<Pair<T, T>, Double>> iterator = pairs.iterator();
while (iterator.hasNext()) {
// now for each pair do appropriate training
Pair<Pair<T, T>, Double> pairDoublePair = iterator.next();
// That's probably ugly and probably should be improved somehow
T element1 = pairDoublePair.getFirst().getFirst();
T element2 = pairDoublePair.getFirst().getSecond();
double weight = pairDoublePair.getSecond(); //coOccurrences.getCoOccurrenceCount(element1, element2);
if (weight <= 0) {
// log.warn("Skipping pair ("+ element1.getLabel()+", " + element2.getLabel()+")");
pairsCounter.incrementAndGet();
continue;
}
errorCounter.incrementCount(epochId, iterateSample(element1, element2, weight));
if (pairsCounter.incrementAndGet() % 1000000 == 0) {
log.info("Processed [" + pairsCounter.get() + "] word pairs so far...");
}
}
}
}
}
public static class Builder<T extends SequenceElement> {
protected double xMax = 100.0d;
protected double alpha = 0.75d;
protected double learningRate = 0.0d;
protected boolean shuffle = false;
protected boolean symmetric = false;
protected int maxmemory = 0;
protected int batchSize = 1000;
public Builder() {
}
/**
* This parameter specifies the batch size for each thread. If shuffle == true, each batch will be shuffled before processing. Default value: 1000.
*
* @param batchSize
* @return
*/
public Builder<T> batchSize(int batchSize) {
this.batchSize = batchSize;
return this;
}
/**
* Initial learning rate; default 0.05
*
* @param eta
* @return
*/
public Builder<T> learningRate(double eta) {
this.learningRate = eta;
return this;
}
/**
* Parameter in exponent of weighting function; default 0.75
*
* @param alpha
* @return
*/
public Builder<T> alpha(double alpha) {
this.alpha = alpha;
return this;
}
/**
* This method allows you to specify maximum memory available for CoOccurrence map builder.
*
* Please note: this option can be considered a debugging aid. In most cases, setting a proper -Xmx JVM argument is enough to limit this algorithm.
* Please note: this option won't override the -Xmx JVM value.
*
* @param gbytes memory limit, in gigabytes
* @return
*/
public Builder<T> maxMemory(int gbytes) {
this.maxmemory = gbytes;
return this;
}
/**
* Parameter specifying cutoff in weighting function; default 100.0
*
* @param xMax
* @return
*/
public Builder<T> xMax(double xMax) {
this.xMax = xMax;
return this;
}
/**
* Specifies whether the co-occurrence list should be shuffled between training epochs
*
* @param reallyShuffle
* @return
*/
public Builder<T> shuffle(boolean reallyShuffle) {
this.shuffle = reallyShuffle;
return this;
}
/**
* Specifies whether the co-occurrence list should be built in both directions from the current word.
*
* @param reallySymmetric
* @return
*/
public Builder<T> symmetric(boolean reallySymmetric) {
this.symmetric = reallySymmetric;
return this;
}
public GloVe<T> build() {
GloVe<T> ret = new GloVe<>();
ret.symmetric = this.symmetric;
ret.shuffle = this.shuffle;
ret.xMax = this.xMax;
ret.alpha = this.alpha;
ret.learningRate = this.learningRate;
ret.maxmemory = this.maxmemory;
ret.batchSize = this.batchSize;
return ret;
}
}
}
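
A minimal usage sketch for the builder above (VocabWord stands in for T; the values shown are the class's documented defaults; imports as in the file above):

GloVe<VocabWord> glove = new GloVe.Builder<VocabWord>()
        .learningRate(0.05)   // initial learning rate
        .alpha(0.75)          // exponent of the weighting function
        .xMax(100.0)          // cutoff of the weighting function
        .batchSize(1000)      // pairs fetched per batch, per thread
        .shuffle(false)       // shuffle each batch before processing
        .symmetric(false)     // count co-occurrences in both directions
        .build();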

View File

@@ -64,7 +64,6 @@ import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.fasttext.FastText;
import org.deeplearning4j.models.glove.Glove;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceElementFactory;
@@ -142,10 +141,6 @@ import lombok.val;
* {@link #readParagraphVectors(String)}
* {@link #readParagraphVectors(InputStream)}
*
* <li>Serializers for GloVe:</li>
* {@link #writeWordVectors(Glove, File)}
* {@link #writeWordVectors(Glove, String)}
* {@link #writeWordVectors(Glove, OutputStream)}
*
* <li>Adapters</li>
* {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
@@ -154,7 +149,6 @@ import lombok.val;
* {@link #loadTxt(InputStream)}
*
* <li>Serializers to tSNE format</li>
* {@link #writeTsneFormat(Glove, INDArray, File)}
* {@link #writeTsneFormat(Word2Vec, INDArray, File)}
*
* <li>FastText serializer:</li>
@@ -974,7 +968,7 @@ public class WordVectorSerializer {
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes,
@NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
// first we load syn0
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors));
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors)); //Note stream is closed in loadTxt
InMemoryLookupTable lookupTable = pair.getFirst();
lookupTable.setNegative(configuration.getNegative());
if (configuration.getNegative() > 0)
@@ -1161,48 +1155,6 @@ public class WordVectorSerializer {
}
}
/**
* This method saves GloVe model to the given output stream.
*
* @param vectors GloVe model to be saved
* @param file path where model should be saved to
*/
public static void writeWordVectors(@NonNull Glove vectors, @NonNull File file) {
try (BufferedOutputStream fos = new BufferedOutputStream(new FileOutputStream(file))) {
writeWordVectors(vectors, fos);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/**
* This method saves GloVe model to the given output stream.
*
* @param vectors GloVe model to be saved
* @param path path where model should be saved to
*/
public static void writeWordVectors(@NonNull Glove vectors, @NonNull String path) {
try (BufferedOutputStream fos = new BufferedOutputStream(new FileOutputStream(path))) {
writeWordVectors(vectors, fos);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/**
* This method saves GloVe model to the given OutputStream
*
* @param vectors GloVe model to be saved
* @param stream OutputStream where model should be saved to
*/
public static void writeWordVectors(@NonNull Glove vectors, @NonNull OutputStream stream) {
try {
writeWordVectors(vectors.lookupTable(), stream);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
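
All three removed overloads are thin wrappers that end up in writeWordVectors(vectors.lookupTable(), stream). Assuming that lookup-table overload remains public, equivalent behaviour for any model exposing lookupTable() is a one-liner (a sketch, not the removed API; model is any WordVectorsImpl-style object):

try (OutputStream stream = new BufferedOutputStream(new FileOutputStream(file))) {
    WordVectorSerializer.writeWordVectors(model.lookupTable(), stream);
}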
/**
* This method saves paragraph vectors to the given output stream.
*
@@ -1655,7 +1607,7 @@ public class WordVectorSerializer {
*/
@Deprecated
public static WordVectors loadTxtVectors(File vectorsFile) throws IOException {
FileInputStream fileInputStream = new FileInputStream(vectorsFile);
FileInputStream fileInputStream = new FileInputStream(vectorsFile); //Note stream is closed in loadTxt
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(fileInputStream);
return fromPair(pair);
}
@@ -1877,43 +1829,6 @@ public class WordVectorSerializer {
return fromPair(Pair.makePair((InMemoryLookupTable) lookupTable, (VocabCache) cache));
}
/**
* Write the tsne format
*
* @param vec the word vectors to use for labeling
* @param tsne the tsne array to write
* @param csv the file to use
* @throws Exception
*/
public static void writeTsneFormat(Glove vec, INDArray tsne, File csv) throws Exception {
try (BufferedWriter write = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(csv), StandardCharsets.UTF_8))) {
int words = 0;
InMemoryLookupCache l = (InMemoryLookupCache) vec.vocab();
for (String word : vec.vocab().words()) {
if (word == null) {
continue;
}
StringBuilder sb = new StringBuilder();
INDArray wordVector = tsne.getRow(l.wordFor(word).getIndex());
for (int j = 0; j < wordVector.length(); j++) {
sb.append(wordVector.getDouble(j));
if (j < wordVector.length() - 1) {
sb.append(",");
}
}
sb.append(",");
sb.append(word.replaceAll(" ", WHITESPACE_REPLACEMENT));
sb.append(" ");
sb.append("\n");
write.write(sb.toString());
}
log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
}
}
/**
* Write the tsne format
*

View File

@@ -1,652 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove;
import lombok.NonNull;
import org.deeplearning4j.models.glove.count.*;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.iterators.FilteredSequenceIterator;
import org.deeplearning4j.models.sequencevectors.iterators.SynchronizedSequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SynchronizedSentenceIterator;
import org.deeplearning4j.common.util.DL4JFileUtils;
import org.nd4j.common.util.ThreadUtils;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
* This class implements co-occurrence map building for an abstract training corpus.
* However, its performance is rather low due to the excessive IO that happens in ShadowCopyThread.
*
* PLEASE NOTE: The current implementation involves massive IO, and it should be rewritten as soon as ND4J gets sparse array support
*
* @author raver119@gmail.com
*/
public class AbstractCoOccurrences<T extends SequenceElement> implements Serializable {
protected boolean symmetric;
protected int windowSize;
protected VocabCache<T> vocabCache;
protected SequenceIterator<T> sequenceIterator;
// please note: we need to leave room for the ShadowCopy thread, hence the -1 here
protected int workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);
// target file where the co-occurrence text should be saved
protected File targetFile;
protected ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
protected long memory_threshold = 0;
private ShadowCopyThread shadowThread;
// private Counter<Integer> sentenceOccurrences = Util.parallelCounter();
//private CounterMap<T, T> coOccurrenceCounts = Util.parallelCounterMap();
private volatile CountMap<T> coOccurrenceCounts = new CountMap<>();
//private Counter<Integer> occurrenceAllocations = Util.parallelCounter();
//private List<Pair<T, T>> coOccurrences;
private AtomicLong processedSequences = new AtomicLong(0);
protected static final Logger logger = LoggerFactory.getLogger(AbstractCoOccurrences.class);
// this method should be private, to avoid non-configured instantiation
private AbstractCoOccurrences() {}
/**
* This method returns the co-occurrence distance weight for two SequenceElements
*
* @param element1
* @param element2
* @return distance weight
*/
public double getCoOccurrenceCount(@NonNull T element1, @NonNull T element2) {
return coOccurrenceCounts.getCount(element1, element2);
}
/**
* This method returns the estimated memory footprint, based on the current CountMap content
* @return
*/
protected long getMemoryFootprint() {
// TODO: implement this method. It should return approx. memory used by appropriate CountMap
try {
lock.readLock().lock();
return ((long) coOccurrenceCounts.size()) * 24L * 5L;
} finally {
lock.readLock().unlock();
}
}
/**
* This method returns the memory threshold, defined as 1/2 of the memory allowed for allocation
* @return
*/
protected long getMemoryThreshold() {
return memory_threshold / 2L;
}
public void fit() {
shadowThread = new ShadowCopyThread();
shadowThread.start();
// we should reset iterator before counting cooccurrences
sequenceIterator.reset();
List<CoOccurrencesCalculatorThread> threads = new ArrayList<>();
for (int x = 0; x < workers; x++) {
threads.add(x, new CoOccurrencesCalculatorThread(x, new FilteredSequenceIterator<>(
new SynchronizedSequenceIterator<>(sequenceIterator), vocabCache), processedSequences));
threads.get(x).start();
}
for (int x = 0; x < workers; x++) {
try {
threads.get(x).join();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
shadowThread.finish();
logger.info("CoOccurrences map was built.");
}
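
A minimal usage sketch for this class, mirroring how the GloVe implementation above drives it (VocabWord stands in for T; vocabCache and sequenceIterator are assumed to be prepared elsewhere; imports as in the file above):

AbstractCoOccurrences<VocabWord> coOccurrences = new AbstractCoOccurrences.Builder<VocabWord>()
        .vocabCache(vocabCache)
        .iterate(sequenceIterator)
        .symmetric(true)
        .windowSize(5)
        .build();
coOccurrences.fit(); // counts pairs, spilling to disk whenever the memory threshold is hit

Iterator<Pair<Pair<VocabWord, VocabWord>, Double>> pairs = coOccurrences.iterator();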
/**
*
* This method returns an iterator over element pairs and their weights. The resulting iterator is safe to use in a multi-threaded environment.
*
* Developer's note: thread safety of the returned iterator is delegated to PrefetchingSentenceIterator
* @return
*/
public Iterator<Pair<Pair<T, T>, Double>> iterator() {
final SentenceIterator iterator;
try {
iterator = new SynchronizedSentenceIterator(
new PrefetchingSentenceIterator.Builder(new BasicLineIterator(targetFile))
.setFetchSize(500000).build());
} catch (Exception e) {
logger.error("Target file was not found on last stage!");
throw new RuntimeException(e);
}
return new Iterator<Pair<Pair<T, T>, Double>>() {
/*
iterator should be built on top of current text file with all pairs
*/
@Override
public boolean hasNext() {
return iterator.hasNext();
}
@Override
public Pair<Pair<T, T>, Double> next() {
String line = iterator.nextSentence();
String[] strings = line.split(" ");
T element1 = vocabCache.elementAtIndex(Integer.valueOf(strings[0]));
T element2 = vocabCache.elementAtIndex(Integer.valueOf(strings[1]));
Double weight = Double.valueOf(strings[2]);
return new Pair<>(new Pair<>(element1, element2), weight);
}
@Override
public void remove() {
throw new UnsupportedOperationException("remove() method can't be supported on read-only interface");
}
};
}
public static class Builder<T extends SequenceElement> {
protected boolean symmetric;
protected int windowSize = 5;
protected VocabCache<T> vocabCache;
protected SequenceIterator<T> sequenceIterator;
protected int workers = Runtime.getRuntime().availableProcessors();
protected File target;
protected long maxmemory = Runtime.getRuntime().maxMemory();
public Builder() {
}
public Builder<T> symmetric(boolean reallySymmetric) {
this.symmetric = reallySymmetric;
return this;
}
public Builder<T> windowSize(int windowSize) {
this.windowSize = windowSize;
return this;
}
public Builder<T> vocabCache(@NonNull VocabCache<T> cache) {
this.vocabCache = cache;
return this;
}
public Builder<T> iterate(@NonNull SequenceIterator<T> iterator) {
this.sequenceIterator = new SynchronizedSequenceIterator<>(iterator);
return this;
}
public Builder<T> workers(int numWorkers) {
this.workers = numWorkers;
return this;
}
/**
* This method allows you to specify maximum memory available for CoOccurrence map builder.
*
* Please note: this option can be considered a debugging aid. In most cases, setting a proper -Xmx JVM argument is enough to limit this algorithm.
* Please note: this option won't override the -Xmx JVM value.
*
* @param gbytes memory available, in GigaBytes
* @return
*/
public Builder<T> maxMemory(int gbytes) {
if (gbytes > 0) {
this.maxmemory = Math.max(gbytes - 1, 1) * 1024 * 1024 * 1024L;
}
return this;
}
/**
* Path to save the co-occurrence map after construction.
* If no target file is specified, a temporary file will be used.
*
* @param path
* @return
*/
public Builder<T> targetFile(@NonNull String path) {
this.targetFile(new File(path));
return this;
}
/**
* Path to save the co-occurrence map after construction.
* If no target file is specified, a temporary file will be used.
*
* @param file
* @return
*/
public Builder<T> targetFile(@NonNull File file) {
this.target = file;
return this;
}
public AbstractCoOccurrences<T> build() {
AbstractCoOccurrences<T> ret = new AbstractCoOccurrences<>();
ret.sequenceIterator = this.sequenceIterator;
ret.windowSize = this.windowSize;
ret.vocabCache = this.vocabCache;
ret.symmetric = this.symmetric;
ret.workers = this.workers;
if (this.maxmemory < 1) {
this.maxmemory = Runtime.getRuntime().maxMemory();
}
ret.memory_threshold = this.maxmemory;
logger.info("Actual memory limit: [" + this.maxmemory + "]");
// use temp file, if no target file was specified
try {
if (this.target == null) {
this.target = DL4JFileUtils.createTempFile("cooccurrence", "map");
}
this.target.deleteOnExit();
} catch (Exception e) {
throw new RuntimeException(e);
}
ret.targetFile = this.target;
return ret;
}
}
private class CoOccurrencesCalculatorThread extends Thread implements Runnable {
private final SequenceIterator<T> iterator;
private final AtomicLong sequenceCounter;
private int threadId;
public CoOccurrencesCalculatorThread(int threadId, @NonNull SequenceIterator<T> iterator,
@NonNull AtomicLong sequenceCounter) {
this.iterator = iterator;
this.sequenceCounter = sequenceCounter;
this.threadId = threadId;
this.setName("CoOccurrencesCalculatorThread " + threadId);
}
@Override
public void run() {
while (iterator.hasMoreSequences()) {
Sequence<T> sequence = iterator.nextSequence();
List<String> tokens = new ArrayList<>(sequence.asLabels());
// logger.info("Tokens size: " + tokens.size());
for (int x = 0; x < sequence.getElements().size(); x++) {
int wordIdx = vocabCache.indexOf(tokens.get(x));
if (wordIdx < 0) {
continue;
}
String w1 = vocabCache.wordFor(tokens.get(x)).getLabel();
// THIS IS SAFE TO REMOVE; THERE IS NO CHANCE WE'LL HAVE AN UNK WORD INSIDE A SEQUENCE
/*if(w1.equals(Glove.UNK))
continue;
*/
int windowStop = Math.min(x + windowSize + 1, tokens.size());
for (int j = x; j < windowStop; j++) {
int otherWord = vocabCache.indexOf(tokens.get(j));
if (otherWord < 0) {
continue;
}
String w2 = vocabCache.wordFor(tokens.get(j)).getLabel();
if (w2.equals(Glove.DEFAULT_UNK) || otherWord == wordIdx) {
continue;
}
T tokenX = vocabCache.wordFor(tokens.get(x));
T tokenJ = vocabCache.wordFor(tokens.get(j));
double nWeight = 1.0 / (j - x + Nd4j.EPS_THRESHOLD);
while (getMemoryFootprint() >= getMemoryThreshold()) {
shadowThread.invoke();
/*lock.readLock().lock();
int size = coOccurrenceCounts.size();
lock.readLock().unlock();
*/
if (threadId == 0) {
logger.debug("Memory consuimption > threshold: {footrpint: [" + getMemoryFootprint()
+ "], threshold: [" + getMemoryThreshold() + "] }");
}
ThreadUtils.uncheckedSleep(10000);
}
/*
if (getMemoryFootprint() == 0) {
logger.info("Zero size!");
}
*/
try {
lock.readLock().lock();
if (wordIdx < otherWord) {
coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight);
if (symmetric) {
coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight);
}
} else {
coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight);
if (symmetric) {
coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight);
}
}
} finally {
lock.readLock().unlock();
}
}
}
sequenceCounter.incrementAndGet();
}
}
}
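
The inner loop above accumulates distance-weighted counts: every pair of in-vocabulary words at token distance d within the window contributes 1/d (plus a small epsilon to avoid division by zero), matching the counting scheme of the Stanford GloVe paper:

X_{ij} \mathrel{+}= \frac{1}{d}, \qquad 1 \le d \le \text{windowSize}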
/**
* This class is designed to provide shadow copy functionality for CoOccurrence maps, since with a realistically sized corpus such a map cannot fit into memory
*
*/
private class ShadowCopyThread extends Thread implements Runnable {
private AtomicBoolean isFinished = new AtomicBoolean(false);
private AtomicBoolean isTerminate = new AtomicBoolean(false);
private AtomicBoolean isInvoked = new AtomicBoolean(false);
private AtomicBoolean shouldInvoke = new AtomicBoolean(false);
// files that contain results from previous runs
private File[] tempFiles;
private RoundCount counter;
public ShadowCopyThread() {
try {
counter = new RoundCount(1);
tempFiles = new File[2];
tempFiles[0] = DL4JFileUtils.createTempFile("aco", "tmp");
tempFiles[1] = DL4JFileUtils.createTempFile("aco", "tmp");
tempFiles[0].deleteOnExit();
tempFiles[1].deleteOnExit();
} catch (Exception e) {
throw new RuntimeException(e);
}
this.setName("ACO ShadowCopy thread");
}
@Override
public void run() {
/*
The basic idea is pretty simple: run quietly until memory fills up to some high-water mark.
As soon as this happens, execute a shadow copy.
*/
while (!isFinished.get() && !isTerminate.get()) {
// check used memory: if usage is below the threshold, sleep for a while; if above, invoke the copier
if (getMemoryFootprint() > getMemoryThreshold() || (shouldInvoke.get() && !isInvoked.get())) {
// we'll just invoke copier, nothing else
shouldInvoke.compareAndSet(true, false);
invokeBlocking();
} else {
/*
commented and left here for future debugging purposes, if needed
//lock.readLock().lock();
//int size = coOccurrenceCounts.size();
//lock.readLock().unlock();
//logger.info("Current memory situation: {size: [" +size+ "], footprint: [" + getMemoryFootprint()+"], threshold: ["+ getMemoryThreshold() +"]}");
*/
ThreadUtils.uncheckedSleep(1000);
}
}
}
/**
* This method advises the shadow copy process to start
*/
public void invoke() {
shouldInvoke.compareAndSet(false, true);
}
/**
* This method dumps the co-occurrence map into the save file.
* Please note: this method is synchronized and will block until complete
*/
public synchronized void invokeBlocking() {
if (getMemoryFootprint() < getMemoryThreshold() && !isFinished.get()) {
return;
}
int numberOfLinesSaved = 0;
isInvoked.set(true);
logger.debug("Memory purge started.");
/*
Basic plan:
1. Open temp file
2. Read that file line by line
3. For each line read, merge the in-memory counts into the new file
*/
counter.tick();
CountMap<T> localMap;
try {
// at any given moment there is only one write lock holder, since invokeBlocking() is a synchronized call
lock.writeLock().lock();
// obtain local copy of CountMap
localMap = coOccurrenceCounts;
// set new CountMap, and release write lock
coOccurrenceCounts = new CountMap<>();
} finally {
lock.writeLock().unlock();
}
try {
File file = null;
if (!isFinished.get()) {
file = tempFiles[counter.previous()];
} else
file = targetFile;
// PrintWriter pw = new PrintWriter(file);
int linesRead = 0;
logger.debug("Saving to: [" + counter.get() + "], Reading from: [" + counter.previous() + "]");
CoOccurenceReader<T> reader =
new BinaryCoOccurrenceReader<>(tempFiles[counter.previous()], vocabCache, localMap);
CoOccurrenceWriter<T> writer = (isFinished.get()) ? new ASCIICoOccurrenceWriter<T>(targetFile)
: new BinaryCoOccurrenceWriter<T>(tempFiles[counter.get()]);
while (reader.hasMoreObjects()) {
CoOccurrenceWeight<T> line = reader.nextObject();
if (line != null) {
writer.writeObject(line);
numberOfLinesSaved++;
linesRead++;
}
}
reader.finish();
logger.debug("Lines read: [" + linesRead + "]");
//now we can dump the remaining elements that were not present in the existing dump
Iterator<Pair<T, T>> iterator = localMap.getPairIterator();
while (iterator.hasNext()) {
Pair<T, T> pair = iterator.next();
double mWeight = localMap.getCount(pair);
CoOccurrenceWeight<T> object = new CoOccurrenceWeight<>();
object.setElement1(pair.getFirst());
object.setElement2(pair.getSecond());
object.setWeight(mWeight);
writer.writeObject(object);
numberOfLinesSaved++;
// if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
}
writer.finish();
/*
SentenceIterator sIterator = new PrefetchingSentenceIterator.Builder(new BasicLineIterator(tempFiles[counter.get()]))
.setFetchSize(500000)
.build();
int linesRead = 0;
while (sIterator.hasNext()) {
//List<Writable> list = new ArrayList<>(reader.next());
String sentence = sIterator.nextSentence();
if (sentence == null || sentence.isEmpty()) continue;
String[] strings = sentence.split(" ");
// first two elements are integers - vocab indexes
//T element1 = vocabCache.wordFor(vocabCache.wordAtIndex(list.get(0).toInt()));
//T element2 = vocabCache.wordFor(vocabCache.wordAtIndex(list.get(1).toInt()));
T element1 = vocabCache.elementAtIndex(Integer.valueOf(strings[0]));
T element2 = vocabCache.elementAtIndex(Integer.valueOf(strings[1]));
// getting third element, previously stored weight
double sWeight = Double.valueOf(strings[2]); // list.get(2).toDouble();
// now, since we have both elements ready, we can check this pair against inmemory map
double mWeight = localMap.getCount(element1, element2);
if (mWeight <= 0) {
// this means we have no such pair in memory, so we'll do nothing to sWeight
} else {
// since we have new weight value in memory, we should update sWeight value before moving it off memory
sWeight += mWeight;
// original pair can be safely removed from CountMap
localMap.removePair(element1,element2);
}
StringBuilder builder = new StringBuilder().append(element1.getIndex()).append(" ").append(element2.getIndex()).append(" ").append(sWeight);
pw.println(builder.toString());
numberOfLinesSaved++;
linesRead++;
// if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
// if (linesRead % 100000 == 0) logger.info("Lines read: [" + linesRead +"]");
}
*/
/*
logger.info("Lines read: [" + linesRead + "]");
//now, we can dump the rest of elements, which were not presented in existing dump
Iterator<Pair<T, T>> iterator = localMap.getPairIterator();
while (iterator.hasNext()) {
Pair<T, T> pair = iterator.next();
double mWeight = localMap.getCount(pair);
StringBuilder builder = new StringBuilder().append(pair.getFirst().getIndex()).append(" ").append(pair.getFirst().getIndex()).append(" ").append(mWeight);
pw.println(builder.toString());
numberOfLinesSaved++;
// if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
}
pw.flush();
pw.close();
*/
// just a hint for gc
localMap = null;
//sIterator.finish();
} catch (Exception e) {
throw new RuntimeException(e);
}
logger.info("Number of word pairs saved so far: [" + numberOfLinesSaved + "]");
isInvoked.set(false);
}
/**
* This method provides a soft finish for the shadow copy process.
* Please note: this is a blocking call, since it requires a final merge.
*/
public void finish() {
if (this.isFinished.get()) {
return;
}
this.isFinished.set(true);
invokeBlocking();
}
/**
* This method provides a hard finish for the shadow copy process
*/
public void terminate() {
this.isTerminate.set(true);
}
}
}

View File

@@ -1,444 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.learning.impl.elements.GloVe;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener;
import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.StreamLineIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import java.util.Collection;
import java.util.List;
/**
* GlobalVectors (GloVe) standalone implementation for DL4J.
* Based on the original Stanford GloVe paper: <a href="http://www-nlp.stanford.edu/pubs/glove.pdf">http://www-nlp.stanford.edu/pubs/glove.pdf</a>
*
* @author raver119@gmail.com
*/
public class Glove extends SequenceVectors<VocabWord> {
protected Glove() {
}
public static class Builder extends SequenceVectors.Builder<VocabWord> {
private double xMax;
private boolean shuffle;
private boolean symmetric;
protected double alpha = 0.75d;
private int maxmemory = (int) (Runtime.getRuntime().totalMemory() / 1024 / 1024 / 1024);
protected TokenizerFactory tokenFactory;
protected SentenceIterator sentenceIterator;
protected DocumentIterator documentIterator;
public Builder() {
super();
}
public Builder(@NonNull VectorsConfiguration configuration) {
super(configuration);
}
/**
* This method has no effect for GloVe
*
* @param vec existing WordVectors model
* @return
*/
@Override
public Builder useExistingWordVectors(@NonNull WordVectors vec) {
return this;
}
@Override
public Builder iterate(@NonNull SequenceIterator<VocabWord> iterator) {
super.iterate(iterator);
return this;
}
/**
* Specifies minibatch size for training process.
*
* @param batchSize
* @return
*/
@Override
public Builder batchSize(int batchSize) {
super.batchSize(batchSize);
return this;
}
/**
* Iterations and epochs are the same in the GloVe implementation.
*
* @param iterations
* @return
*/
@Override
public Builder iterations(int iterations) {
super.epochs(iterations);
return this;
}
/**
* Sets the number of iterations over the training corpus during training
*
* @param numEpochs
* @return
*/
@Override
public Builder epochs(int numEpochs) {
super.epochs(numEpochs);
return this;
}
@Override
public Builder useAdaGrad(boolean reallyUse) {
super.useAdaGrad(true);
return this;
}
@Override
public Builder layerSize(int layerSize) {
super.layerSize(layerSize);
return this;
}
@Override
public Builder learningRate(double learningRate) {
super.learningRate(learningRate);
return this;
}
/**
* Sets the minimum word frequency used during vocabulary construction.
* Please note: this option is ignored if the vocabulary is built outside of GloVe
*
* @param minWordFrequency
* @return
*/
@Override
public Builder minWordFrequency(int minWordFrequency) {
super.minWordFrequency(minWordFrequency);
return this;
}
@Override
public Builder minLearningRate(double minLearningRate) {
super.minLearningRate(minLearningRate);
return this;
}
@Override
public Builder resetModel(boolean reallyReset) {
super.resetModel(reallyReset);
return this;
}
@Override
public Builder vocabCache(@NonNull VocabCache<VocabWord> vocabCache) {
super.vocabCache(vocabCache);
return this;
}
@Override
public Builder lookupTable(@NonNull WeightLookupTable<VocabWord> lookupTable) {
super.lookupTable(lookupTable);
return this;
}
@Override
@Deprecated
public Builder sampling(double sampling) {
super.sampling(sampling);
return this;
}
@Override
@Deprecated
public Builder negativeSample(double negative) {
super.negativeSample(negative);
return this;
}
@Override
public Builder stopWords(@NonNull List<String> stopList) {
super.stopWords(stopList);
return this;
}
@Override
public Builder trainElementsRepresentation(boolean trainElements) {
super.trainElementsRepresentation(true);
return this;
}
@Override
@Deprecated
public Builder trainSequencesRepresentation(boolean trainSequences) {
super.trainSequencesRepresentation(false);
return this;
}
@Override
public Builder stopWords(@NonNull Collection<VocabWord> stopList) {
super.stopWords(stopList);
return this;
}
@Override
public Builder windowSize(int windowSize) {
super.windowSize(windowSize);
return this;
}
@Override
public Builder seed(long randomSeed) {
super.seed(randomSeed);
return this;
}
@Override
public Builder workers(int numWorkers) {
super.workers(numWorkers);
return this;
}
/**
* Sets TokenizerFactory to be used for training
*
* @param tokenizerFactory
* @return
*/
public Builder tokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
this.tokenFactory = tokenizerFactory;
return this;
}
/**
* Parameter specifying cutoff in weighting function; default 100.0
*
* @param xMax
* @return
*/
public Builder xMax(double xMax) {
this.xMax = xMax;
return this;
}
/**
* Specifies whether the co-occurrence list should be built in both directions from the current word.
*
* @param reallySymmetric
* @return
*/
public Builder symmetric(boolean reallySymmetric) {
this.symmetric = reallySymmetric;
return this;
}
/**
* Specifies whether the co-occurrence list should be shuffled between training epochs
*
* @param reallyShuffle
* @return
*/
public Builder shuffle(boolean reallyShuffle) {
this.shuffle = reallyShuffle;
return this;
}
/**
* This method has no effect for GloVe
*
* @param windows
* @return
*/
@Override
public Builder useVariableWindow(int... windows) {
// no-op
return this;
}
/**
* Parameter in exponent of weighting function; default 0.75
*
* @param alpha
* @return
*/
public Builder alpha(double alpha) {
this.alpha = alpha;
return this;
}
public Builder iterate(@NonNull SentenceIterator iterator) {
this.sentenceIterator = iterator;
return this;
}
public Builder iterate(@NonNull DocumentIterator iterator) {
this.sentenceIterator = new StreamLineIterator.Builder(iterator).setFetchSize(100).build();
return this;
}
/**
* Sets the ModelUtils instance used as the provider for utility methods: similarity(), wordsNearest(), accuracy(), etc.
*
* @param modelUtils model utils to be used
* @return
*/
@Override
public Builder modelUtils(@NonNull ModelUtils<VocabWord> modelUtils) {
super.modelUtils(modelUtils);
return this;
}
/**
* This method sets VectorsListeners for this SequenceVectors model
*
* @param vectorsListeners
* @return
*/
@Override
public Builder setVectorsListeners(@NonNull Collection<VectorsListener<VocabWord>> vectorsListeners) {
super.setVectorsListeners(vectorsListeners);
return this;
}
/**
* This method allows you to specify maximum memory available for CoOccurrence map builder.
*
* Please note: this option can be considered a debugging aid. In most cases, setting a proper -Xmx JVM argument is enough to limit this algorithm.
* Please note: this option won't override the -Xmx JVM value.
*
* @param gbytes memory limit, in gigabytes
* @return
*/
public Builder maxMemory(int gbytes) {
this.maxmemory = gbytes;
return this;
}
/**
* This method allows you to specify the SequenceElement that will be used as the UNK element, if UNK is used
*
* @param element
* @return
*/
@Override
public Builder unknownElement(VocabWord element) {
super.unknownElement(element);
return this;
}
/**
* This method allows you to specify whether the UNK word should be used internally
*
* @param reallyUse
* @return
*/
@Override
public Builder useUnknown(boolean reallyUse) {
super.useUnknown(reallyUse);
if (this.unknownElement == null) {
this.unknownElement(new VocabWord(1.0, Glove.DEFAULT_UNK));
}
return this;
}
public Glove build() {
presetTables();
Glove ret = new Glove();
// hardcoded value for glove
if (sentenceIterator != null) {
SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(sentenceIterator)
.tokenizerFactory(tokenFactory).build();
this.iterator = new AbstractSequenceIterator.Builder<>(transformer).build();
}
ret.trainElementsVectors = true;
ret.trainSequenceVectors = false;
ret.useAdeGrad = true;
this.useAdaGrad = true;
ret.learningRate.set(this.learningRate);
ret.resetModel = this.resetModel;
ret.batchSize = this.batchSize;
ret.iterator = this.iterator;
ret.numEpochs = this.numEpochs;
ret.numIterations = this.iterations;
ret.layerSize = this.layerSize;
ret.useUnknown = this.useUnknown;
ret.unknownElement = this.unknownElement;
this.configuration.setLearningRate(this.learningRate);
this.configuration.setLayersSize(layerSize);
this.configuration.setHugeModelExpected(hugeModelExpected);
this.configuration.setWindow(window);
this.configuration.setMinWordFrequency(minWordFrequency);
this.configuration.setIterations(iterations);
this.configuration.setSeed(seed);
this.configuration.setBatchSize(batchSize);
this.configuration.setLearningRateDecayWords(learningRateDecayWords);
this.configuration.setMinLearningRate(minLearningRate);
this.configuration.setSampling(this.sampling);
this.configuration.setUseAdaGrad(useAdaGrad);
this.configuration.setNegative(negative);
this.configuration.setEpochs(this.numEpochs);
ret.configuration = this.configuration;
ret.lookupTable = this.lookupTable;
ret.vocab = this.vocabCache;
ret.modelUtils = this.modelUtils;
ret.eventListeners = this.vectorsListeners;
ret.elementsLearningAlgorithm = new GloVe.Builder<VocabWord>().learningRate(this.learningRate)
.shuffle(this.shuffle).symmetric(this.symmetric).xMax(this.xMax).alpha(this.alpha)
.maxMemory(maxmemory).build();
return ret;
}
}
}
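
For the record, end-to-end training with this now-removed class followed the usual SequenceVectors builder shape. A sketch (sentenceIterator and tokenizerFactory assumed to be set up elsewhere; hyperparameter values are illustrative):

Glove glove = new Glove.Builder()
        .iterate(sentenceIterator)          // training corpus
        .tokenizerFactory(tokenizerFactory)
        .layerSize(100)                     // vector dimensionality
        .learningRate(0.05)
        .epochs(25)
        .xMax(100)
        .alpha(0.75)
        .symmetric(true)
        .shuffle(true)
        .build();
glove.fit();
double sim = glove.similarity("day", "night"); // inherited from SequenceVectors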

View File

@@ -1,334 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.legacy.AdaGrad;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
/**
* Glove lookup table
*
* @author Adam Gibson
*/
// Deprecated: the logic was pulled out of the WeightLookupTable classes into the LearningAlgorithm interfaces for cleaner code.
@Deprecated
public class GloveWeightLookupTable<T extends SequenceElement> extends InMemoryLookupTable<T> {
private AdaGrad weightAdaGrad;
private AdaGrad biasAdaGrad;
private INDArray bias;
//exponent of the weighting function (the value called alpha in the GloVe paper; the field name is historical)
private double xMax = 0.75;
//cutoff of the weighting function (the value called x_max in the paper)
private double maxCount = 100;
public GloveWeightLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, Random gen,
double negative, double xMax, double maxCount) {
super(vocab, vectorLength, useAdaGrad, lr, gen, negative);
this.xMax = xMax;
this.maxCount = maxCount;
}
@Override
public void resetWeights(boolean reset) {
if (rng == null)
this.rng = Nd4j.getRandom();
//note the +1, which leaves room for the unk vocab word
if (syn0 == null || reset) {
syn0 = Nd4j.rand(new int[] {vocab.numWords() + 1, vectorLength}, rng).subi(0.5).divi((double) vectorLength);
INDArray randUnk = Nd4j.rand(1, vectorLength, rng).subi(0.5).divi(vectorLength);
putVector(Word2Vec.DEFAULT_UNK, randUnk);
}
if (weightAdaGrad == null || reset) {
weightAdaGrad = new AdaGrad(new long[]{vocab.numWords() + 1, vectorLength}, lr.get());
}
//right after unknown
if (bias == null || reset)
bias = Nd4j.create(syn0.rows());
if (biasAdaGrad == null || reset) {
biasAdaGrad = new AdaGrad(bias.shape(), lr.get());
}
}
/**
* Reset the weights of the cache
*/
@Override
public void resetWeights() {
resetWeights(true);
}
/**
* glove iteration
* @param w1 the first word
* @param w2 the second word
* @param score the weight learned for the particular co occurrences
*/
public double iterateSample(T w1, T w2, double score) {
INDArray w1Vector = syn0.slice(w1.getIndex());
INDArray w2Vector = syn0.slice(w2.getIndex());
//prediction: input + bias
if (w1.getIndex() < 0 || w1.getIndex() >= syn0.rows())
throw new IllegalArgumentException("Illegal index for word " + w1.getLabel());
if (w2.getIndex() < 0 || w2.getIndex() >= syn0.rows())
throw new IllegalArgumentException("Illegal index for word " + w2.getLabel());
//w1 * w2 + bias
double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
double weight = Math.pow(Math.min(1.0, (score / maxCount)), xMax);
double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
if (Double.isNaN(fDiff))
fDiff = Nd4j.EPS_THRESHOLD;
//amount of change
double gradient = fDiff;
//note the update step here: the gradient is
//the gradient of the OPPOSITE word
//for adagrad we will use the index of the word passed in
//for the gradient calculation we will use the context vector
update(w1, w1Vector, w2Vector, gradient);
update(w2, w2Vector, w1Vector, gradient);
return fDiff;
}
private void update(T w1, INDArray wordVector, INDArray contextVector, double gradient) {
//gradient for word vectors
INDArray grad1 = contextVector.mul(gradient);
INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape());
//update vector
wordVector.subi(update);
double w1Bias = bias.getDouble(w1.getIndex());
double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape());
double update2 = w1Bias - biasGradient;
bias.putScalar(w1.getIndex(), update2);
}
public AdaGrad getWeightAdaGrad() {
return weightAdaGrad;
}
public AdaGrad getBiasAdaGrad() {
return biasAdaGrad;
}
/**
* Load a glove model from an input stream.
* The format is:
* word num1 num2....
* @param is the input stream to read from for the weights
* @param vocab the vocab for the lookup table
* @return the loaded model
* @throws java.io.IOException if one occurs
*/
public static GloveWeightLookupTable load(InputStream is, VocabCache<? extends SequenceElement> vocab)
throws IOException {
LineIterator iter = IOUtils.lineIterator(is, "UTF-8");
GloveWeightLookupTable glove = null;
Map<String, float[]> wordVectors = new HashMap<>();
while (iter.hasNext()) {
String line = iter.nextLine().trim();
if (line.isEmpty())
continue;
String[] split = line.split(" ");
String word = split[0];
if (glove == null)
glove = new GloveWeightLookupTable.Builder().cache(vocab).vectorLength(split.length - 1).build();
if (word.isEmpty())
continue;
float[] read = read(split, glove.layerSize());
if (read.length < 1)
continue;
wordVectors.put(word, read);
}
glove.setSyn0(weights(glove, wordVectors, vocab));
glove.resetWeights(false);
iter.close();
return glove;
}
private static INDArray weights(GloveWeightLookupTable glove, Map<String, float[]> data, VocabCache vocab) {
INDArray ret = Nd4j.create(data.size(), glove.layerSize());
for (Map.Entry<String, float[]> entry : data.entrySet()) {
String key = entry.getKey();
INDArray row = Nd4j.create(Nd4j.createBuffer(entry.getValue()));
if (row.length() != glove.layerSize())
continue;
if (vocab.indexOf(key) >= data.size())
continue;
if (vocab.indexOf(key) < 0)
continue;
ret.putRow(vocab.indexOf(key), row);
}
return ret;
}
private static float[] read(String[] split, int length) {
float[] ret = new float[length];
for (int i = 1; i < split.length; i++) {
ret[i - 1] = Float.parseFloat(split[i]);
}
return ret;
}
@Override
public void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha) {
throw new UnsupportedOperationException();
}
public double getxMax() {
return xMax;
}
public void setxMax(double xMax) {
this.xMax = xMax;
}
public double getMaxCount() {
return maxCount;
}
public void setMaxCount(double maxCount) {
this.maxCount = maxCount;
}
public INDArray getBias() {
return bias;
}
public void setBias(INDArray bias) {
this.bias = bias;
}
public static class Builder<T extends SequenceElement> extends InMemoryLookupTable.Builder<T> {
private double xMax = 0.75;
private double maxCount = 100;
public Builder<T> maxCount(double maxCount) {
this.maxCount = maxCount;
return this;
}
public Builder<T> xMax(double xMax) {
this.xMax = xMax;
return this;
}
@Override
public Builder<T> cache(VocabCache<T> vocab) {
super.cache(vocab);
return this;
}
@Override
public Builder<T> negative(double negative) {
super.negative(negative);
return this;
}
@Override
public Builder<T> vectorLength(int vectorLength) {
super.vectorLength(vectorLength);
return this;
}
@Override
public Builder<T> useAdaGrad(boolean useAdaGrad) {
super.useAdaGrad(useAdaGrad);
return this;
}
@Override
public Builder<T> lr(double lr) {
super.lr(lr);
return this;
}
@Override
public Builder<T> gen(Random gen) {
super.gen(gen);
return this;
}
@Override
public Builder<T> seed(long seed) {
super.seed(seed);
return this;
}
public GloveWeightLookupTable<T> build() {
return new GloveWeightLookupTable<>(vocabCache, vectorLength, useAdaGrad, lr, gen, negative, xMax,
maxCount);
}
}
}
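For orientation, here is a minimal standalone sketch of the weighting logic iterateSample applies above; it is an illustration, not part of the original API. Note the naming quirk: the field called xMax holds the exponent (alpha in the GloVe paper), while maxCount plays the role of the paper's x_max cutoff.

// Sketch: GloVe weighting f(x) = min(1, x / xMaxCutoff)^alpha and the weighted error term.
// Names are illustrative; 'alpha' corresponds to the 'xMax' field above, 'xMaxCutoff' to 'maxCount'.
public class GloveWeightingSketch {

    static double weight(double count, double xMaxCutoff, double alpha) {
        return Math.pow(Math.min(1.0, count / xMaxCutoff), alpha);
    }

    // f(x) * (w1.w2 + b1 + b2 - log x): the per-pair error that iterateSample returns
    static double weightedError(double prediction, double count, double xMaxCutoff, double alpha) {
        return weight(count, xMaxCutoff, alpha) * (prediction - Math.log(count));
    }

    public static void main(String[] args) {
        // hypothetical values: prediction 0.2, a pair co-occurring 50 times, cutoff 100, alpha 0.75
        System.out.println(weightedError(0.2, 50, 100, 0.75));
    }
}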

View File

@ -1,91 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove.count;
import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import java.io.File;
import java.io.PrintWriter;
/**
* @author raver119@gmail.com
*/
public class ASCIICoOccurrenceReader<T extends SequenceElement> implements CoOccurenceReader<T> {
private File file;
private PrintWriter writer;
private SentenceIterator iterator;
private VocabCache<T> vocabCache;
public ASCIICoOccurrenceReader(@NonNull File file, @NonNull VocabCache<T> vocabCache) {
this.vocabCache = vocabCache;
this.file = file;
try {
iterator = new PrefetchingSentenceIterator.Builder(new BasicLineIterator(file)).build();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public boolean hasMoreObjects() {
return iterator.hasNext();
}
/**
* Returns the next CoOccurrenceWeight object.
*
* PLEASE NOTE: This method can return null.
* @return the next CoOccurrenceWeight, or null if the next line is empty or absent
*/
@Override
public CoOccurrenceWeight<T> nextObject() {
String line = iterator.nextSentence();
if (line == null || line.isEmpty()) {
return null;
}
String[] strings = line.split(" ");
CoOccurrenceWeight<T> object = new CoOccurrenceWeight<>();
object.setElement1(vocabCache.elementAtIndex(Integer.valueOf(strings[0])));
object.setElement2(vocabCache.elementAtIndex(Integer.valueOf(strings[1])));
object.setWeight(Double.parseDouble(strings[2]));
return object;
}
@Override
public void finish() {
try {
if (writer != null) {
writer.flush();
writer.close();
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

View File

@ -1,69 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove.count;
import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.PrintWriter;
/**
* @author raver119@gmail.com
*/
public class ASCIICoOccurrenceWriter<T extends SequenceElement> implements CoOccurrenceWriter<T> {
private File file;
private PrintWriter writer;
public ASCIICoOccurrenceWriter(@NonNull File file) {
this.file = file;
try {
this.writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(file), 10 * 1024 * 1024));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public void writeObject(CoOccurrenceWeight<T> object) {
StringBuilder builder = new StringBuilder(String.valueOf(object.getElement1().getIndex())).append(" ")
.append(String.valueOf(object.getElement2().getIndex())).append(" ")
.append(String.valueOf(object.getWeight()));
writer.println(builder.toString());
}
@Override
public void queueObject(CoOccurrenceWeight<T> object) {
throw new UnsupportedOperationException();
}
@Override
public void finish() {
try {
writer.flush();
} catch (Exception e) {
}
try {
writer.close();
} catch (Exception e) {
}
}
}
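For reference, ASCIICoOccurrenceWriter above and ASCIICoOccurrenceReader share a plain one-record-per-line format: the two element indices followed by the weight, space-separated. A hypothetical two-record file would look like this:

0 1 3.14159265
1 2 0.197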

View File

@ -1,245 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove.count;
import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Binary implementation of the CoOccurenceReader interface, used to stream co-occurrence maps generated for GloVe from disk rather than holding them in memory
*
* @author raver119@gmail.com
*/
public class BinaryCoOccurrenceReader<T extends SequenceElement> implements CoOccurenceReader<T> {
private VocabCache<T> vocabCache;
private InputStream inputStream;
private File file;
private ArrayBlockingQueue<CoOccurrenceWeight<T>> buffer;
int workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);
private StreamReaderThread readerThread;
private CountMap<T> countMap;
protected static final Logger logger = LoggerFactory.getLogger(BinaryCoOccurrenceReader.class);
public BinaryCoOccurrenceReader(@NonNull File file, @NonNull VocabCache<T> vocabCache, CountMap<T> map) {
this.vocabCache = vocabCache;
this.file = file;
this.countMap = map;
buffer = new ArrayBlockingQueue<>(200000);
try {
inputStream = new BufferedInputStream(new FileInputStream(this.file), 100 * 1024 * 1024);
readerThread = new StreamReaderThread(inputStream);
readerThread.start();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public boolean hasMoreObjects() {
if (!buffer.isEmpty())
return true;
try {
return readerThread.hasMoreObjects() || !buffer.isEmpty();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public CoOccurrenceWeight<T> nextObject() {
if (!buffer.isEmpty()) {
return buffer.poll();
} else {
// buffer can be starved, or we're already at the end of file.
if (readerThread.hasMoreObjects()) {
try {
return buffer.poll(3, TimeUnit.SECONDS);
} catch (Exception e) {
return null;
}
}
}
return null;
}
@Override
public void finish() {
try {
if (inputStream != null)
inputStream.close();
} catch (Exception e) {
//
}
}
private class StreamReaderThread extends Thread implements Runnable {
private InputStream stream;
private AtomicBoolean isReading = new AtomicBoolean(false);
public StreamReaderThread(@NonNull InputStream stream) {
this.stream = stream;
isReading.set(false);
}
@Override
public void run() {
try {
// read a pre-defined number of records into a byte array
byte[] array = new byte[16 * 500000];
while (true) {
int count = stream.read(array);
isReading.set(true);
if (count <= 0) //read() returns -1 at end of stream
break;
// now we deserialize them in separate threads to gain some speedup, if possible
List<AsyncDeserializationThread> threads = new ArrayList<>();
AtomicInteger internalPosition = new AtomicInteger(0);
for (int t = 0; t < workers; t++) {
threads.add(t, new AsyncDeserializationThread(t, array, buffer, internalPosition, count));
threads.get(t).start();
}
// we'll block this cycle until all objects have been put into the queue
for (int t = 0; t < workers; t++) {
try {
threads.get(t).join();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
isReading.set(false);
if (count < array.length)
break;
}
} catch (Exception e) {
isReading.set(false);
throw new RuntimeException(e);
}
}
public boolean hasMoreObjects() {
try {
return stream.available() > 0 || isReading.get();
} catch (Exception e) {
return false;
}
}
}
/**
* Utility class that accepts a byte array as input and deserializes it into a set of CoOccurrenceWeight objects
*/
private class AsyncDeserializationThread extends Thread implements Runnable {
private int threadId;
private byte[] arrayReference;
private ArrayBlockingQueue<CoOccurrenceWeight<T>> targetBuffer;
private AtomicInteger pointer;
private int limit;
public AsyncDeserializationThread(int threadId, @NonNull byte[] array,
@NonNull ArrayBlockingQueue<CoOccurrenceWeight<T>> targetBuffer,
@NonNull AtomicInteger sharedPointer, int limit) {
this.threadId = threadId;
this.arrayReference = array;
this.targetBuffer = targetBuffer;
this.pointer = sharedPointer;
this.limit = limit;
setName("AsynDeserialization thread " + this.threadId);
}
@Override
public void run() {
ByteBuffer bB = ByteBuffer.wrap(arrayReference);
int position = 0;
while ((position = pointer.getAndAdd(16)) < this.limit) {
int e1idx = bB.getInt(position);
int e2idx = bB.getInt(position + 4);
double eW = bB.getDouble(position + 8);
CoOccurrenceWeight<T> object = new CoOccurrenceWeight<>();
object.setElement1(vocabCache.elementAtIndex(e1idx));
object.setElement2(vocabCache.elementAtIndex(e2idx));
if (countMap != null) {
double mW = countMap.getCount(object.getElement1(), object.getElement2());
if (mW > 0) {
eW += mW;
countMap.removePair(object.getElement1(), object.getElement2());
}
}
object.setWeight(eW);
try {
targetBuffer.put(object);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
}
}
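A brief sketch of the fixed 16-byte record layout this reader (and the matching BinaryCoOccurrenceWriter below) assumes: two 4-byte ints plus one 8-byte double, hence the getAndAdd(16) stride. It also shows how threads can claim disjoint records from one buffer through a shared AtomicInteger; class and variable names here are illustrative only.

import java.nio.ByteBuffer;
import java.util.concurrent.atomic.AtomicInteger;

public class RecordPartitionSketch {
    public static void main(String[] args) {
        // two 16-byte records: (elementIndex1, elementIndex2, weight)
        ByteBuffer buf = ByteBuffer.allocate(32);
        buf.putInt(0, 1).putInt(4, 2).putDouble(8, 3.14159265);
        buf.putInt(16, 2).putInt(20, 3).putDouble(24, 0.197);

        AtomicInteger pointer = new AtomicInteger(0);
        int limit = 32;
        int position;
        // each worker thread runs this same loop; getAndAdd hands out disjoint 16-byte offsets
        while ((position = pointer.getAndAdd(16)) < limit) {
            int e1 = buf.getInt(position);
            int e2 = buf.getInt(position + 4);
            double w = buf.getDouble(position + 8);
            System.out.println(e1 + " " + e2 + " " + w);
        }
    }
}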

View File

@ -1,78 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove.count;
import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
/**
* @author raver119@gmail.com
*/
public class BinaryCoOccurrenceWriter<T extends SequenceElement> implements CoOccurrenceWriter<T> {
private File file;
private DataOutputStream outputStream;
private static final Logger log = LoggerFactory.getLogger(BinaryCoOccurrenceWriter.class);
public BinaryCoOccurrenceWriter(@NonNull File file) {
this.file = file;
try {
outputStream = new DataOutputStream(
new BufferedOutputStream(new FileOutputStream(file), 100 * 1024 * 1024));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public void writeObject(@NonNull CoOccurrenceWeight<T> object) {
try {
// log.info("Saving objects: { [" +object.getElement1().getIndex() +"], [" + object.getElement2().getIndex() + "] }");
outputStream.writeInt(object.getElement1().getIndex());
outputStream.writeInt(object.getElement2().getIndex());
outputStream.writeDouble(object.getWeight());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public void queueObject(CoOccurrenceWeight<T> object) {
throw new UnsupportedOperationException();
}
@Override
public void finish() {
try {
outputStream.flush();
} catch (Exception e) {
}
try {
outputStream.close();
} catch (Exception e) {
}
}
}

View File

@ -1,34 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove.count;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
/**
* Created by raver on 24.12.2015.
*/
public interface CoOccurenceReader<T extends SequenceElement> {
/*
Storage->Memory merging part
*/
boolean hasMoreObjects();
CoOccurrenceWeight<T> nextObject();
void finish();
}

View File

@ -1,54 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove.count;
import lombok.Data;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
/**
* Simple POJO holding a pair of elements and their co-occurrence weight, used in GloVe co-occurrence counting
*
* @author raver119@gmail.com
*/
@Data
public class CoOccurrenceWeight<T extends SequenceElement> {
private T element1;
private T element2;
private double weight;
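//note: equality and hashCode below consider only the element pair; the weight is not compared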
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
CoOccurrenceWeight<?> that = (CoOccurrenceWeight<?>) o;
if (element1 != null ? !element1.equals(that.element1) : that.element1 != null)
return false;
return element2 != null ? element2.equals(that.element2) : that.element2 == null;
}
@Override
public int hashCode() {
int result = element1 != null ? element1.hashCode() : 0;
result = 31 * result + (element2 != null ? element2.hashCode() : 0);
return result;
}
}

View File

@ -1,43 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove.count;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
/**
* Created by fartovii on 25.12.15.
*/
public interface CoOccurrenceWriter<T extends SequenceElement> {
/**
* Implementations of this method should write the object out immediately.
* @param object the co-occurrence pair to write
*/
void writeObject(CoOccurrenceWeight<T> object);
/**
* Implementations of this method should queue the object for later writing.
*
* @param object the co-occurrence pair to queue
*/
void queueObject(CoOccurrenceWeight<T> object);
/**
* Implementations of this method should flush and close all resources they use.
*/
void finish();
}

View File

@ -1,99 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove.count;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.nd4j.common.primitives.Pair;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* Drop-in replacement for CounterMap
*
* WORK IN PROGRESS, PLEASE DO NOT USE
*
* @author raver119@gmail.com
*/
public class CountMap<T extends SequenceElement> {
private volatile Map<Pair<T, T>, AtomicDouble> backingMap = new ConcurrentHashMap<>();
public CountMap() {
// placeholder
}
public void incrementCount(T element1, T element2, double weight) {
Pair<T, T> tempEntry = new Pair<>(element1, element2);
//putIfAbsent + addAndGet avoids the lost-update race of a check-then-put sequence
AtomicDouble previous = backingMap.putIfAbsent(tempEntry, new AtomicDouble(weight));
if (previous != null)
previous.addAndGet(weight);
}
public void removePair(T element1, T element2) {
Pair<T, T> tempEntry = new Pair<>(element1, element2);
backingMap.remove(tempEntry);
}
public void removePair(Pair<T, T> pair) {
backingMap.remove(pair);
}
public double getCount(T element1, T element2) {
Pair<T, T> tempEntry = new Pair<>(element1, element2);
AtomicDouble count = backingMap.get(tempEntry);
return count == null ? 0 : count.get();
}
public double getCount(Pair<T, T> pair) {
AtomicDouble count = backingMap.get(pair);
return count == null ? 0 : count.get();
}
public Iterator<Pair<T, T>> getPairIterator() {
return new Iterator<Pair<T, T>>() {
private Iterator<Pair<T, T>> iterator = backingMap.keySet().iterator();
@Override
public boolean hasNext() {
return iterator.hasNext();
}
@Override
public Pair<T, T> next() {
return iterator.next();
}
@Override
public void remove() {
throw new UnsupportedOperationException("remove() isn't supported here");
}
};
}
public int size() {
return backingMap.size();
}
}
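A short, hypothetical usage sketch for CountMap, with the work-in-progress caveat above in mind; VocabWord construction mirrors the tests elsewhere in this changeset.

CountMap<VocabWord> counts = new CountMap<>();
VocabWord day = new VocabWord(1.0, "day");
VocabWord night = new VocabWord(2.0, "night");
counts.incrementCount(day, night, 1.0);
counts.incrementCount(day, night, 0.5);
System.out.println(counts.getCount(day, night)); // 1.5
System.out.println(counts.size());               // 1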

View File

@ -1,86 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove.count;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
* Simple circular counter that cycles through lower...limit, both inclusive (lower defaults to 0)
*
* @author raver119@gmail.com
*/
public class RoundCount {
private int limit = 0;
private int lower = 0;
private int value = 0;
private ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
/**
* Creates a new RoundCount instance.
*
* @param limit Maximum value for this counter. Inclusive.
*/
public RoundCount(int limit) {
this.limit = limit;
}
/**
* Creates a new RoundCount instance.
*
* @param lower Minimum value for this counter. Inclusive.
* @param top Maximum value for this counter. Inclusive.
*/
public RoundCount(int lower, int top) {
this.limit = top;
this.lower = lower;
}
public int previous() {
try {
lock.readLock().lock();
if (value == lower)
return limit;
else
return value - 1;
} finally {
lock.readLock().unlock();
}
}
public int get() {
try {
lock.readLock().lock();
return value;
} finally {
lock.readLock().unlock();
}
}
public void tick() {
try {
lock.writeLock().lock();
if (value == limit)
value = lower;
else
value++;
} finally {
lock.writeLock().unlock();
}
}
}
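A quick usage sketch; RoundCountTest later in this changeset exercises the same behavior.

RoundCount count = new RoundCount(2);  // cycles 0, 1, 2, 0, ...
count.tick();                          // value becomes 1
System.out.println(count.get());      // 1
System.out.println(count.previous()); // 0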

View File

@ -763,7 +763,7 @@ public class ParagraphVectors extends Word2Vec {
/**
* This method allows you to use pre-built WordVectors model (Word2Vec or GloVe) for ParagraphVectors.
* This method allows you to use pre-built WordVectors model (e.g. Word2Vec) for ParagraphVectors.
* Existing model will be transferred into new model before training starts.
*
* PLEASE NOTE: Non-normalized model is recommended to use here.

View File

@ -520,7 +520,7 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
}
/**
* This method allows you to use pre-built WordVectors model (SkipGram or GloVe) for DBOW sequence learning.
* This method allows you to use pre-built WordVectors model (e.g. SkipGram) for DBOW sequence learning.
* Existing model will be transferred into new model before training starts.
*
* PLEASE NOTE: This model has no effect for elements learning algorithms. Only sequence learning is affected.

View File

@ -1,101 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove;
import org.deeplearning4j.BaseDL4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabConstructor;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.common.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
/**
* @author raver119@gmail.com
*/
public class AbstractCoOccurrencesTest extends BaseDL4JTest {
private static final Logger log = LoggerFactory.getLogger(AbstractCoOccurrencesTest.class);
@Before
public void setUp() throws Exception {
}
@Test
public void testFit1() throws Exception {
ClassPathResource resource = new ClassPathResource("other/oneline.txt");
File file = resource.getFile();
AbstractCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
BasicLineIterator underlyingIterator = new BasicLineIterator(file);
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
SentenceTransformer transformer =
new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build();
AbstractSequenceIterator<VocabWord> sequenceIterator =
new AbstractSequenceIterator.Builder<>(transformer).build();
VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>()
.addSource(sequenceIterator, 1).setTargetVocabCache(vocabCache).build();
constructor.buildJointVocabulary(false, true);
AbstractCoOccurrences<VocabWord> coOccurrences = new AbstractCoOccurrences.Builder<VocabWord>()
.iterate(sequenceIterator).vocabCache(vocabCache).symmetric(false).windowSize(15).build();
coOccurrences.fit();
Iterator<Pair<Pair<VocabWord, VocabWord>, Double>> iterator = coOccurrences.iterator();
assertNotEquals(null, iterator);
int cnt = 0;
List<Pair<VocabWord, VocabWord>> list = new ArrayList<>();
while (iterator.hasNext()) {
Pair<Pair<VocabWord, VocabWord>, Double> pair = iterator.next();
list.add(pair.getFirst());
cnt++;
}
log.info("CoOccurrences: " + list);
assertEquals(16, list.size());
assertEquals(16, cnt);
}
}

View File

@ -1,137 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove;
import org.deeplearning4j.BaseDL4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.resources.Resources;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.util.Collection;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* Created by agibsonccc on 12/3/14.
*/
public class GloveTest extends BaseDL4JTest {
private static final Logger log = LoggerFactory.getLogger(GloveTest.class);
private Glove glove;
private SentenceIterator iter;
@Before
public void before() throws Exception {
ClassPathResource resource = new ClassPathResource("/raw_sentences.txt");
File file = resource.getFile();
iter = new LineSentenceIterator(file);
iter.setPreProcessor(new SentencePreProcessor() {
@Override
public String preProcess(String sentence) {
return sentence.toLowerCase();
}
});
}
@Ignore
@Test
public void testGlove() throws Exception {
/*
glove = new Glove.Builder().iterate(iter).symmetric(true).shuffle(true)
.minWordFrequency(1).iterations(10).learningRate(0.1)
.layerSize(300)
.build();
glove.fit();
Collection<String> words = glove.wordsNearest("day", 20);
log.info("Nearest words to 'day': " + words);
assertTrue(words.contains("week"));
*/
}
@Ignore
@Test
public void testGloVe1() throws Exception {
File inputFile = Resources.asFile("big/raw_sentences.txt");
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
// Split on white spaces in the line to get words
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
Glove glove = new Glove.Builder().iterate(iter).tokenizerFactory(t).alpha(0.75).learningRate(0.1).epochs(45)
.xMax(100).shuffle(true).symmetric(true).build();
glove.fit();
double simD = glove.similarity("day", "night");
double simP = glove.similarity("best", "police");
log.info("Day/night similarity: " + simD);
log.info("Best/police similarity: " + simP);
Collection<String> words = glove.wordsNearest("day", 10);
log.info("Nearest words to 'day': " + words);
assertTrue(simD > 0.7);
// simP should actually be close to 0
assertTrue(simP < 0.5);
assertTrue(words.contains("night"));
assertTrue(words.contains("year"));
assertTrue(words.contains("week"));
File tempFile = File.createTempFile("glove", "temp");
tempFile.deleteOnExit();
INDArray day1 = glove.getWordVectorMatrix("day").dup();
WordVectorSerializer.writeWordVectors(glove, tempFile);
WordVectors vectors = WordVectorSerializer.loadTxtVectors(tempFile);
INDArray day2 = vectors.getWordVectorMatrix("day").dup();
assertEquals(day1, day2);
tempFile.delete();
}
}

View File

@ -1,156 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove.count;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import static org.junit.Assert.assertNotEquals;
/**
* Created by fartovii on 25.12.15.
*/
public class BinaryCoOccurrenceReaderTest extends BaseDL4JTest {
private static final Logger log = LoggerFactory.getLogger(BinaryCoOccurrenceReaderTest.class);
@Before
public void setUp() throws Exception {
}
@Test
public void testHasMoreObjects1() throws Exception {
File tempFile = File.createTempFile("tmp", "tmp");
tempFile.deleteOnExit();
VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
VocabWord word1 = new VocabWord(1.0, "human");
VocabWord word2 = new VocabWord(2.0, "animal");
VocabWord word3 = new VocabWord(3.0, "unknown");
vocabCache.addToken(word1);
vocabCache.addToken(word2);
vocabCache.addToken(word3);
Huffman huffman = new Huffman(vocabCache.vocabWords());
huffman.build();
huffman.applyIndexes(vocabCache);
BinaryCoOccurrenceWriter<VocabWord> writer = new BinaryCoOccurrenceWriter<>(tempFile);
CoOccurrenceWeight<VocabWord> object1 = new CoOccurrenceWeight<>();
object1.setElement1(word1);
object1.setElement2(word2);
object1.setWeight(3.14159265);
writer.writeObject(object1);
CoOccurrenceWeight<VocabWord> object2 = new CoOccurrenceWeight<>();
object2.setElement1(word2);
object2.setElement2(word3);
object2.setWeight(0.197);
writer.writeObject(object2);
writer.finish();
BinaryCoOccurrenceReader<VocabWord> reader = new BinaryCoOccurrenceReader<>(tempFile, vocabCache, null);
CoOccurrenceWeight<VocabWord> r1 = reader.nextObject();
log.info("Object received: " + r1);
assertNotEquals(null, r1);
r1 = reader.nextObject();
log.info("Object received: " + r1);
assertNotEquals(null, r1);
}
@Test
public void testHasMoreObjects2() throws Exception {
File tempFile = File.createTempFile("tmp", "tmp");
tempFile.deleteOnExit();
VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
VocabWord word1 = new VocabWord(1.0, "human");
VocabWord word2 = new VocabWord(2.0, "animal");
VocabWord word3 = new VocabWord(3.0, "unknown");
vocabCache.addToken(word1);
vocabCache.addToken(word2);
vocabCache.addToken(word3);
Huffman huffman = new Huffman(vocabCache.vocabWords());
huffman.build();
huffman.applyIndexes(vocabCache);
BinaryCoOccurrenceWriter<VocabWord> writer = new BinaryCoOccurrenceWriter<>(tempFile);
CoOccurrenceWeight<VocabWord> object1 = new CoOccurrenceWeight<>();
object1.setElement1(word1);
object1.setElement2(word2);
object1.setWeight(3.14159265);
writer.writeObject(object1);
CoOccurrenceWeight<VocabWord> object2 = new CoOccurrenceWeight<>();
object2.setElement1(word2);
object2.setElement2(word3);
object2.setWeight(0.197);
writer.writeObject(object2);
CoOccurrenceWeight<VocabWord> object3 = new CoOccurrenceWeight<>();
object3.setElement1(word1);
object3.setElement2(word3);
object3.setWeight(0.001);
writer.writeObject(object3);
writer.finish();
BinaryCoOccurrenceReader<VocabWord> reader = new BinaryCoOccurrenceReader<>(tempFile, vocabCache, null);
CoOccurrenceWeight<VocabWord> r1 = reader.nextObject();
log.info("Object received: " + r1);
assertNotEquals(null, r1);
r1 = reader.nextObject();
log.info("Object received: " + r1);
assertNotEquals(null, r1);
r1 = reader.nextObject();
log.info("Object received: " + r1);
assertNotEquals(null, r1);
}
}

View File

@ -1,90 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.glove.count;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
/**
* Created by fartovii on 23.12.15.
*/
public class RoundCountTest extends BaseDL4JTest {
@Before
public void setUp() throws Exception {
}
@Test
public void testGet1() throws Exception {
RoundCount count = new RoundCount(1);
assertEquals(0, count.get());
count.tick();
assertEquals(1, count.get());
count.tick();
assertEquals(0, count.get());
}
@Test
public void testGet2() throws Exception {
RoundCount count = new RoundCount(3);
assertEquals(0, count.get());
count.tick();
assertEquals(1, count.get());
count.tick();
assertEquals(2, count.get());
count.tick();
assertEquals(3, count.get());
count.tick();
assertEquals(0, count.get());
}
@Test
public void testPrevious1() throws Exception {
RoundCount count = new RoundCount(3);
assertEquals(0, count.get());
assertEquals(3, count.previous());
count.tick();
assertEquals(1, count.get());
assertEquals(0, count.previous());
count.tick();
assertEquals(2, count.get());
assertEquals(1, count.previous());
count.tick();
assertEquals(3, count.get());
assertEquals(2, count.previous());
count.tick();
assertEquals(0, count.get());
assertEquals(3, count.previous());
}
}

View File

@ -21,12 +21,10 @@ import lombok.Getter;
import lombok.Setter;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.BaseDL4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.impl.elements.GloVe;
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
@ -55,6 +53,7 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.slf4j.Logger;
@ -270,65 +269,6 @@ public class SequenceVectorsTest extends BaseDL4JTest {
.epochs(1).resetModel(false).trainElementsRepresentation(false).build();
}
@Ignore
@Test
public void testGlove1() throws Exception {
logger.info("Max available memory: " + Runtime.getRuntime().maxMemory());
ClassPathResource resource = new ClassPathResource("big/raw_sentences.txt");
File file = resource.getFile();
BasicLineIterator underlyingIterator = new BasicLineIterator(file);
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
SentenceTransformer transformer =
new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build();
AbstractSequenceIterator<VocabWord> sequenceIterator =
new AbstractSequenceIterator.Builder<>(transformer).build();
VectorsConfiguration configuration = new VectorsConfiguration();
configuration.setWindow(5);
configuration.setLearningRate(0.06);
configuration.setLayersSize(100);
SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(configuration)
.iterate(sequenceIterator).iterations(1).epochs(45)
.elementsLearningAlgorithm(new GloVe.Builder<VocabWord>().shuffle(true).symmetric(true)
.learningRate(0.05).alpha(0.75).xMax(100.0).build())
.resetModel(true).trainElementsRepresentation(true).trainSequencesRepresentation(false).build();
vectors.fit();
double sim = vectors.similarity("day", "night");
logger.info("Day/night similarity: " + sim);
sim = vectors.similarity("day", "another");
logger.info("Day/another similarity: " + sim);
sim = vectors.similarity("night", "year");
logger.info("Night/year similarity: " + sim);
sim = vectors.similarity("night", "me");
logger.info("Night/me similarity: " + sim);
sim = vectors.similarity("day", "know");
logger.info("Day/know similarity: " + sim);
sim = vectors.similarity("best", "police");
logger.info("Best/police similarity: " + sim);
Collection<String> labels = vectors.wordsNearest("day", 10);
logger.info("Nearest labels to 'day': " + labels);
sim = vectors.similarity("day", "night");
assertTrue(sim > 0.6d);
}
@Test
@Ignore
public void testDeepWalk() throws Exception {

View File

@ -49,8 +49,8 @@ import java.io.File;
import java.util.Collection;
import java.util.concurrent.Callable;
import static org.awaitility.Awaitility.await;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@Slf4j
@ -206,12 +206,6 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
await()
.until(new Callable<Boolean>() {
@Override
public Boolean call() {
return net.params().equalsWithEps(restored.params(), 2e-3);
}
});
assertTrue(net.params().equalsWithEps(restored.params(), 2e-3));
}
}

View File

@ -25,6 +25,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -247,10 +248,8 @@ public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.c
@Override
public int[] predict(INDArray input) {
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
int[] ret = new int[input.rows()];
for (int i = 0; i < ret.length; i++)
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
return ret;
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
return output.argMax(1).toIntVector();
}
/**

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
@ -251,10 +252,8 @@ public class LossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.LossL
@Override
public int[] predict(INDArray input) {
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
int[] ret = new int[input.rows()];
for (int i = 0; i < ret.length; i++)
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
return ret;
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
return output.argMax(1).toIntVector();
}
/**

View File

@ -2220,14 +2220,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
if (d.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
int[] ret = new int[(int) d.size(0)];
if (d.isRowVectorOrScalar())
ret[0] = Nd4j.getBlasWrapper().iamax(output);
else {
for (int i = 0; i < ret.length; i++)
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
}
return ret;
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
return output.argMax(1).toIntVector();
}
/**

View File

@ -1,280 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.spark.models.embeddings.glove;
import org.apache.commons.math3.util.FastMath;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCalculator;
import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCounts;
import org.deeplearning4j.spark.text.functions.TextPipeline;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.legacy.AdaGrad;
import org.nd4j.common.primitives.CounterMap;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;
import static org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables.*;
/**
* Spark GloVe implementation
*
* @author Adam Gibson
*/
public class Glove implements Serializable {
private Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast;
private String tokenizerFactoryClazz = DefaultTokenizerFactory.class.getName();
private boolean symmetric = true;
private int windowSize = 15;
private int iterations = 300;
private static Logger log = LoggerFactory.getLogger(Glove.class);
/**
*
* @param tokenizerFactoryClazz the fully qualified class name of the tokenizer
* @param symmetric whether the co-occurrence counts should be symmetric
* @param windowSize the window size for co-occurrence counting
* @param iterations the number of iterations
*/
public Glove(String tokenizerFactoryClazz, boolean symmetric, int windowSize, int iterations) {
this.tokenizerFactoryClazz = tokenizerFactoryClazz;
this.symmetric = symmetric;
this.windowSize = windowSize;
this.iterations = iterations;
}
/**
*
* @param symmetric whether the co-occurrence counts should be symmetric
* @param windowSize the window size for co-occurrence counting
* @param iterations the number of iterations
*/
public Glove(boolean symmetric, int windowSize, int iterations) {
this.symmetric = symmetric;
this.windowSize = windowSize;
this.iterations = iterations;
}
private Pair<INDArray, Float> update(AdaGrad weightAdaGrad, AdaGrad biasAdaGrad, INDArray syn0, INDArray bias,
VocabWord w1, INDArray wordVector, INDArray contextVector, double gradient) {
//gradient for word vectors
INDArray grad1 = contextVector.mul(gradient);
INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape());
wordVector.subi(update);
double w1Bias = bias.getDouble(w1.getIndex());
double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape());
double update2 = w1Bias - biasGradient;
bias.putScalar(w1.getIndex(), bias.getDouble(w1.getIndex()) - update2);
return new Pair<>(update, (float) update2);
}
/**
* Train on the corpus
* @param rdd the RDD of raw text to train on
* @return the vocab and weights
*/
public Pair<VocabCache<VocabWord>, GloveWeightLookupTable> train(JavaRDD<String> rdd) throws Exception {
// Each `train()` can use different parameters
final JavaSparkContext sc = new JavaSparkContext(rdd.context());
final SparkConf conf = sc.getConf();
final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class);
final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class);
final double negative = assignVar(NEGATIVE, conf, Double.class);
final int numWords = assignVar(NUM_WORDS, conf, Integer.class);
final int window = assignVar(WINDOW, conf, Integer.class);
final double alpha = assignVar(ALPHA, conf, Double.class);
final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class);
final int iterations = assignVar(ITERATIONS, conf, Integer.class);
final int nGrams = assignVar(N_GRAMS, conf, Integer.class);
final String tokenizer = assignVar(TOKENIZER, conf, String.class);
final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class);
final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class);
Map<String, Object> tokenizerVarMap = new HashMap<String, Object>() {
{
put("numWords", numWords);
put("nGrams", nGrams);
put("tokenizer", tokenizer);
put("tokenPreprocessor", tokenPreprocessor);
put("removeStop", removeStop);
}
};
Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
TextPipeline pipeline = new TextPipeline(rdd, broadcastTokenizerVarMap);
pipeline.buildVocabCache();
pipeline.buildVocabWordListRDD();
// Get total word count
Long totalWordCount = pipeline.getTotalWordCount();
VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD = pipeline.getSentenceWordsCountRDD();
final Pair<VocabCache<VocabWord>, Long> vocabAndNumWords = new Pair<>(vocabCache, totalWordCount);
vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());
final GloveWeightLookupTable gloveWeightLookupTable = new GloveWeightLookupTable.Builder()
.cache(vocabAndNumWords.getFirst()).lr(conf.getDouble(GlovePerformer.ALPHA, 0.01))
.maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100))
.vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300))
.xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75)).build();
gloveWeightLookupTable.resetWeights();
gloveWeightLookupTable.getBiasAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().rows());
gloveWeightLookupTable.getWeightAdaGrad().historicalGradient =
Nd4j.ones(gloveWeightLookupTable.getSyn0().shape());
log.info("Created lookup table of size " + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));
CounterMap<String, String> coOccurrenceCounts = sentenceWordsCountRDD
.map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize))
.fold(new CounterMap<String, String>(), new CoOccurrenceCounts());
Iterator<Pair<String, String>> pair2 = coOccurrenceCounts.getIterator();
List<Triple<String, String, Float>> counts = new ArrayList<>();
while (pair2.hasNext()) {
Pair<String, String> next = pair2.next();
if (coOccurrenceCounts.getCount(next.getFirst(), next.getSecond()) > gloveWeightLookupTable.getMaxCount()) {
coOccurrenceCounts.setCount(next.getFirst(), next.getSecond(),
(float) gloveWeightLookupTable.getMaxCount());
}
counts.add(new Triple<>(next.getFirst(), next.getSecond(),
(float) coOccurrenceCounts.getCount(next.getFirst(), next.getSecond())));
}
log.info("Calculated co occurrences");
JavaRDD<Triple<String, String, Float>> parallel = sc.parallelize(counts);
JavaPairRDD<String, Tuple2<String, Float>> pairs = parallel
.mapToPair(new PairFunction<Triple<String, String, Float>, String, Tuple2<String, Float>>() {
@Override
public Tuple2<String, Tuple2<String, Float>> call(
Triple<String, String, Float> stringStringDoubleTriple) throws Exception {
return new Tuple2<>(stringStringDoubleTriple.getFirst(),
new Tuple2<>(stringStringDoubleTriple.getSecond(),
stringStringDoubleTriple.getThird()));
}
});
JavaPairRDD<VocabWord, Tuple2<VocabWord, Float>> pairsVocab = pairs.mapToPair(
new PairFunction<Tuple2<String, Tuple2<String, Float>>, VocabWord, Tuple2<VocabWord, Float>>() {
@Override
public Tuple2<VocabWord, Tuple2<VocabWord, Float>> call(
Tuple2<String, Tuple2<String, Float>> stringTuple2Tuple2) throws Exception {
VocabWord w1 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1());
VocabWord w2 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1());
return new Tuple2<>(w1, new Tuple2<>(w2, stringTuple2Tuple2._2()._2()));
}
});
for (int i = 0; i < iterations; i++) {
JavaRDD<GloveChange> change =
pairsVocab.map(new Function<Tuple2<VocabWord, Tuple2<VocabWord, Float>>, GloveChange>() {
@Override
public GloveChange call(
Tuple2<VocabWord, Tuple2<VocabWord, Float>> vocabWordTuple2Tuple2)
throws Exception {
VocabWord w1 = vocabWordTuple2Tuple2._1();
VocabWord w2 = vocabWordTuple2Tuple2._2()._1();
INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
INDArray bias = gloveWeightLookupTable.getBias();
double score = vocabWordTuple2Tuple2._2()._2();
double xMax = gloveWeightLookupTable.getxMax();
double maxCount = gloveWeightLookupTable.getMaxCount();
//w1 * w2 + bias
double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
double weight = FastMath.pow(Math.min(1.0, (score / maxCount)), xMax);
double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
if (Double.isNaN(fDiff))
fDiff = Nd4j.EPS_THRESHOLD;
//amount of change
double gradient = fDiff;
Pair<INDArray, Float> w1Update = update(gloveWeightLookupTable.getWeightAdaGrad(),
gloveWeightLookupTable.getBiasAdaGrad(),
gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(),
w1, w1Vector, w2Vector, gradient);
Pair<INDArray, Float> w2Update = update(gloveWeightLookupTable.getWeightAdaGrad(),
gloveWeightLookupTable.getBiasAdaGrad(),
gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(),
w2, w2Vector, w1Vector, gradient);
return new GloveChange(w1, w2, w1Update.getFirst(), w2Update.getFirst(),
w1Update.getSecond(), w2Update.getSecond(), fDiff,
gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient()
.slice(w1.getIndex()),
gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient()
.slice(w2.getIndex()),
                                            gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient()
                                                            .getDouble(w1.getIndex()),
                                            gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient()
                                                            .getDouble(w2.getIndex()));
}
});
List<GloveChange> gloveChanges = change.collect();
double error = 0.0;
for (GloveChange change2 : gloveChanges) {
change2.apply(gloveWeightLookupTable);
error += change2.getError();
}
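            //Reshuffle the pair order between epochs so updates are not applied in the
            //same order every iteration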
            List<Tuple2<VocabWord, Tuple2<VocabWord, Float>>> l = pairsVocab.collect();
Collections.shuffle(l);
pairsVocab = sc.parallelizePairs(l);
log.info("Error at iteration " + i + " was " + error);
}
return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable);
}
}


@ -1,163 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.spark.models.embeddings.glove;
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.Serializable;
/**
* @author Adam Gibson
*/
public class GloveChange implements Serializable {
private VocabWord w1, w2;
private INDArray w1Update, w2Update;
private double w1BiasUpdate, w2BiasUpdate;
private double error;
private INDArray w1History, w2History;
private double w1BiasHistory, w2BiasHistory;
public GloveChange(VocabWord w1, VocabWord w2, INDArray w1Update, INDArray w2Update, double w1BiasUpdate,
double w2BiasUpdate, double error, INDArray w1History, INDArray w2History, double w1BiasHistory,
double w2BiasHistory) {
this.w1 = w1;
this.w2 = w2;
this.w1Update = w1Update;
this.w2Update = w2Update;
this.w1BiasUpdate = w1BiasUpdate;
this.w2BiasUpdate = w2BiasUpdate;
this.error = error;
this.w1History = w1History;
this.w2History = w2History;
this.w1BiasHistory = w1BiasHistory;
this.w2BiasHistory = w2BiasHistory;
}
    /**
     * Apply the accumulated deltas to the lookup table: subtract the weight and bias
     * updates, and add the new AdaGrad history for both words.
     * @param table the GloVe weight lookup table to update
     */
public void apply(GloveWeightLookupTable table) {
table.getBias().putScalar(w1.getIndex(), table.getBias().getDouble(w1.getIndex()) - w1BiasUpdate);
table.getBias().putScalar(w2.getIndex(), table.getBias().getDouble(w2.getIndex()) - w2BiasUpdate);
table.getSyn0().slice(w1.getIndex()).subi(w1Update);
table.getSyn0().slice(w2.getIndex()).subi(w2Update);
table.getWeightAdaGrad().getHistoricalGradient().slice(w1.getIndex()).addi(w1History);
table.getWeightAdaGrad().getHistoricalGradient().slice(w2.getIndex()).addi(w2History);
table.getBiasAdaGrad().getHistoricalGradient().putScalar(w1.getIndex(),
table.getBiasAdaGrad().getHistoricalGradient().getDouble(w1.getIndex()) + w1BiasHistory);
table.getBiasAdaGrad().getHistoricalGradient().putScalar(w2.getIndex(),
                        table.getBiasAdaGrad().getHistoricalGradient().getDouble(w2.getIndex()) + w2BiasHistory);
}
public INDArray getW1History() {
return w1History;
}
public void setW1History(INDArray w1History) {
this.w1History = w1History;
}
public INDArray getW2History() {
return w2History;
}
public void setW2History(INDArray w2History) {
this.w2History = w2History;
}
public double getW1BiasHistory() {
return w1BiasHistory;
}
public void setW1BiasHistory(double w1BiasHistory) {
this.w1BiasHistory = w1BiasHistory;
}
public double getW2BiasHistory() {
return w2BiasHistory;
}
public void setW2BiasHistory(double w2BiasHistory) {
this.w2BiasHistory = w2BiasHistory;
}
public VocabWord getW1() {
return w1;
}
public void setW1(VocabWord w1) {
this.w1 = w1;
}
public VocabWord getW2() {
return w2;
}
public void setW2(VocabWord w2) {
this.w2 = w2;
}
public INDArray getW1Update() {
return w1Update;
}
public void setW1Update(INDArray w1Update) {
this.w1Update = w1Update;
}
public INDArray getW2Update() {
return w2Update;
}
public void setW2Update(INDArray w2Update) {
this.w2Update = w2Update;
}
public double getW1BiasUpdate() {
return w1BiasUpdate;
}
public void setW1BiasUpdate(double w1BiasUpdate) {
this.w1BiasUpdate = w1BiasUpdate;
}
public double getW2BiasUpdate() {
return w2BiasUpdate;
}
public void setW2BiasUpdate(double w2BiasUpdate) {
this.w2BiasUpdate = w2BiasUpdate;
}
public double getError() {
return error;
}
public void setError(double error) {
this.error = error;
}
@Override
public String toString() {
return w1.getIndex() + "," + w2.getIndex() + " error " + error;
}
}


@ -1,171 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.spark.models.embeddings.glove;
import org.apache.spark.broadcast.Broadcast;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.common.primitives.CounterMap;
import java.io.Serializable;
/**
* @author Adam Gibson
*/
public class GloveParam implements Serializable {
private int vectorLength;
private boolean useAdaGrad;
private double lr;
private Random gen;
private double negative;
private double xMax;
private double maxCount;
private Broadcast<CounterMap<String, String>> coOccurrenceCounts;
public GloveParam(int vectorLength, boolean useAdaGrad, double lr, Random gen, double negative, double xMax,
double maxCount, Broadcast<CounterMap<String, String>> coOccurrenceCounts) {
this.vectorLength = vectorLength;
this.useAdaGrad = useAdaGrad;
this.lr = lr;
this.gen = gen;
this.negative = negative;
this.xMax = xMax;
this.maxCount = maxCount;
this.coOccurrenceCounts = coOccurrenceCounts;
}
public int getVectorLength() {
return vectorLength;
}
public void setVectorLength(int vectorLength) {
this.vectorLength = vectorLength;
}
public boolean isUseAdaGrad() {
return useAdaGrad;
}
public void setUseAdaGrad(boolean useAdaGrad) {
this.useAdaGrad = useAdaGrad;
}
public double getLr() {
return lr;
}
public void setLr(double lr) {
this.lr = lr;
}
public Random getGen() {
return gen;
}
public void setGen(Random gen) {
this.gen = gen;
}
public double getNegative() {
return negative;
}
public void setNegative(double negative) {
this.negative = negative;
}
public double getxMax() {
return xMax;
}
public void setxMax(double xMax) {
this.xMax = xMax;
}
public double getMaxCount() {
return maxCount;
}
public void setMaxCount(double maxCount) {
this.maxCount = maxCount;
}
public Broadcast<CounterMap<String, String>> getCoOccurrenceCounts() {
return coOccurrenceCounts;
}
public void setCoOccurrenceCounts(Broadcast<CounterMap<String, String>> coOccurrenceCounts) {
this.coOccurrenceCounts = coOccurrenceCounts;
}
public static class Builder {
private int vectorLength = 300;
private boolean useAdaGrad = true;
private double lr = 0.025;
private Random gen;
private double negative = 5;
private double xMax = 0.75;
private double maxCount = 100;
private Broadcast<CounterMap<String, String>> coOccurrenceCounts;
public Builder vectorLength(int vectorLength) {
this.vectorLength = vectorLength;
return this;
}
public Builder useAdaGrad(boolean useAdaGrad) {
this.useAdaGrad = useAdaGrad;
return this;
}
public Builder lr(double lr) {
this.lr = lr;
return this;
}
public Builder gen(Random gen) {
this.gen = gen;
return this;
}
public Builder negative(double negative) {
this.negative = negative;
return this;
}
public Builder xMax(double xMax) {
this.xMax = xMax;
return this;
}
public Builder maxCount(double maxCount) {
this.maxCount = maxCount;
return this;
}
public Builder coOccurrenceCounts(Broadcast<CounterMap<String, String>> coOccurrenceCounts) {
this.coOccurrenceCounts = coOccurrenceCounts;
return this;
}
public GloveParam build() {
return new GloveParam(vectorLength, useAdaGrad, lr, gen, negative, xMax, maxCount, coOccurrenceCounts);
}
}
}


@ -1,48 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.spark.models.embeddings.glove;
import org.apache.spark.api.java.function.Function;
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.nd4j.common.primitives.Triple;
/**
 * Baseline GloVe performer
*
* @author Adam Gibson
*/
public class GlovePerformer implements Function<Triple<VocabWord, VocabWord, Double>, GloveChange> {
public final static String NAME_SPACE = "org.deeplearning4j.scaleout.perform.models.glove";
public final static String VECTOR_LENGTH = NAME_SPACE + ".length";
public final static String ALPHA = NAME_SPACE + ".alpha";
public final static String X_MAX = NAME_SPACE + ".xmax";
public final static String MAX_COUNT = NAME_SPACE + ".maxcount";
private GloveWeightLookupTable table;
public GlovePerformer(GloveWeightLookupTable table) {
this.table = table;
}
@Override
public GloveChange call(Triple<VocabWord, VocabWord, Double> pair) throws Exception {
        return null; //Stub - this performer does not compute any change
}
}


@ -1,42 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.spark.models.embeddings.glove;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.common.primitives.Triple;
/**
 * Convert string pairs to vocab word pairs
*
* @author Adam Gibson
*/
public class VocabWordPairs implements Function<Triple<String, String, Double>, Triple<VocabWord, VocabWord, Double>> {
private Broadcast<VocabCache<VocabWord>> vocab;
public VocabWordPairs(Broadcast<VocabCache<VocabWord>> vocab) {
this.vocab = vocab;
}
@Override
public Triple<VocabWord, VocabWord, Double> call(Triple<String, String, Double> v1) throws Exception {
        return new Triple<>(vocab.getValue().wordFor(v1.getFirst()),
                        vocab.getValue().wordFor(v1.getSecond()), v1.getThird());
}
}


@ -1,91 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.spark.models.embeddings.glove.cooccurrences;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.CounterMap;
import org.nd4j.common.primitives.Pair;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
/**
 * Calculate co-occurrences based on tokens
*
* @author Adam Gibson
*/
public class CoOccurrenceCalculator implements Function<Pair<List<String>, AtomicLong>, CounterMap<String, String>> {
private boolean symmetric = false;
private Broadcast<VocabCache<VocabWord>> vocab;
private int windowSize = 5;
public CoOccurrenceCalculator(boolean symmetric, Broadcast<VocabCache<VocabWord>> vocab, int windowSize) {
this.symmetric = symmetric;
this.vocab = vocab;
this.windowSize = windowSize;
}
@Override
public CounterMap<String, String> call(Pair<List<String>, AtomicLong> pair) throws Exception {
List<String> sentence = pair.getFirst();
        CounterMap<String, String> coOccurrenceCounts = new CounterMap<>();
        VocabCache<VocabWord> vocab = this.vocab.value();
        for (int i = 0; i < sentence.size(); i++) {
            int wordIdx = vocab.indexOf(sentence.get(i));
            if (wordIdx < 0)
                continue;
            int windowStop = Math.min(i + windowSize + 1, sentence.size());
            for (int j = i; j < windowStop; j++) {
                int otherWord = vocab.indexOf(sentence.get(j));
                if (otherWord < 0)
continue;
if (otherWord == wordIdx)
continue;
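                //Weight each co-occurrence by the inverse distance between the two words
                //within the window, so nearer words contribute more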
                if (wordIdx < otherWord) {
                    coOccurrenceCounts.incrementCount(sentence.get(i), sentence.get(j),
                                    (float) (1.0 / (j - i + Nd4j.EPS_THRESHOLD)));
                    if (symmetric)
                        coOccurrenceCounts.incrementCount(sentence.get(j), sentence.get(i),
                                        (float) (1.0 / (j - i + Nd4j.EPS_THRESHOLD)));
                } else {
                    float coCount = (float) (1.0 / (j - i + Nd4j.EPS_THRESHOLD));
                    coOccurrenceCounts.incrementCount(sentence.get(j), sentence.get(i), coCount);
                    if (symmetric)
                        coOccurrenceCounts.incrementCount(sentence.get(i), sentence.get(j),
                                        (float) (1.0 / (j - i + Nd4j.EPS_THRESHOLD)));
                }
}
}
        return coOccurrenceCounts;
}
}


@ -1,37 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.spark.models.embeddings.glove.cooccurrences;
import org.apache.spark.api.java.function.Function2;
import org.nd4j.common.primitives.CounterMap;
/**
 * Co-occurrence count reduction
* @author Adam Gibson
*/
public class CoOccurrenceCounts implements
Function2<CounterMap<String, String>, CounterMap<String, String>, CounterMap<String, String>> {
@Override
public CounterMap<String, String> call(CounterMap<String, String> v1, CounterMap<String, String> v2)
throws Exception {
v1.incrementAll(v2);
return v1;
}
}


@ -1,62 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.spark.models.embeddings.glove;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.nd4j.common.io.ClassPathResource;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.text.BaseSparkTest;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.common.primitives.Pair;
import java.util.Collection;
import static org.junit.Assert.assertTrue;
/**
* Created by agibsonccc on 1/31/15.
*/
@Ignore
public class GloveTest extends BaseSparkTest {
@Test
public void testGlove() throws Exception {
Glove glove = new Glove(true, 5, 100);
JavaRDD<String> corpus = sc.textFile(new ClassPathResource("big/raw_sentences.txt").getFile().getAbsolutePath())
.map(new Function<String, String>() {
@Override
public String call(String s) throws Exception {
return s.toLowerCase();
}
});
Pair<VocabCache<VocabWord>, GloveWeightLookupTable> table = glove.train(corpus);
WordVectors vectors = WordVectorSerializer
.fromPair(new Pair<>((InMemoryLookupTable) table.getSecond(), (VocabCache) table.getFirst()));
Collection<String> words = vectors.wordsNearest("day", 20);
assertTrue(words.contains("week"));
}
}


@ -26,6 +26,7 @@ import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
@ -42,8 +43,10 @@ import java.io.IOException;
*/
public class TestTransferStatsCollection extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Override
public long getTimeoutMilliseconds() {
return 90_000L;
}
@Test
public void test() throws IOException {
@ -62,9 +65,7 @@ public class TestTransferStatsCollection extends BaseDL4JTest {
new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
.setFeatureExtractor(0).build();
File dir = testDir.newFolder();
File f = new File(dir, "dl4jTestTransferStatsCollection.bin");
net2.setListeners(new StatsListener(new FileStatsStorage(f)));
net2.setListeners(new StatsListener(new InMemoryStatsStorage()));
        //Previously: failed on frozen layers
net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));


@ -88,7 +88,7 @@ namespace sd {
cudaFree(_allocationPointer);
if (_scalarPointer != nullptr)
cudaFree(_scalarPointer);
cudaFreeHost(_scalarPointer);
    if (_reductionPointer != nullptr)
cudaFree(_reductionPointer);


@ -243,9 +243,6 @@ __host__ void ReduceBoolFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStrea
int *dimension, int dimensionLength,
void *reductionPointer,
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
nd4j_printf("Step A%i\n", -1);
if(shape::isEmpty(hXShapeInfo)) {
if(shape::isEmpty(hZShapeInfo))


@ -515,8 +515,8 @@ BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT cudaDecodeBitmapGeneric, (dim3 &
template <bool storeSum, bool isNP2>
__host__ void prescanLauncher(dim3 &blocks, dim3 &threads, int shmem, cudaStream_t *stream, int *g_odata, const int *g_idata, int *g_blockSums, int n, int blockIndex, int baseIndex) {
//printf("Prescan grid: <%i/%i/%i>; threads: <%i/%i/%i>; shareMemSize: %i\n", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, shmem);
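    // prescan performs a block-level parallel prefix sum; per-block totals are written
    // to g_blockSums so a follow-up pass can combine them into the global scan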
prescan<storeSum, isNP2><<<blocks, threads, shmem, *stream>>>(g_odata, g_idata, g_blockSums, n, blockIndex, baseIndex);
sd::DebugHelper::checkErrorCode(stream, "prescan(...) failed");
};
template <typename S, typename T>


@ -41,8 +41,12 @@ namespace sd {
else
numThreads = sd::floorPow2(numElements);
numThreads = sd::math::nd4j_max<int>(1, numThreads);
int numEltsPerBlock = numThreads * 2;
// if this is a non-power-of-2 array, the last block will be non-full
// compute the smallest power of 2 able to compute its scan.
int numEltsLastBlock =
@ -102,8 +106,6 @@ namespace sd {
} else {
sd::prescanLauncher<false, true>(grid, threads, sharedMemSize, stream, dZ, dX, 0, numElements, 0, 0);
}
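        // Check for (and report) any CUDA error raised by the prescan kernel launches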
sd::DebugHelper::checkErrorCode(stream, "prescanArray(...) failed");
}
static void encodeThresholdP2Int_(void **prs, int *dx, Nd4jLong N, int *dz) {


@ -119,7 +119,7 @@ TEST_F(CudaBasicsTests1, TestPairwise_1) {
z.tickWriteHost();
for (int e = 0; e < z.lengthOf(); e++) {
nd4j_printf("step %i\n", e);
//nd4j_printf("step %i\n", e);
ASSERT_NEAR(exp.e<double>(e), z.e<double>(e), 1e-5);
}
}
@ -2822,7 +2822,7 @@ TEST_F(CudaBasicsTests1, execSummaryStats_2) {
// delete cuda stream
cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult);
}
/*
////////////////////////////////////////////////////////////////////////////
TEST_F(CudaBasicsTests1, execSummaryStats_3) {
@ -2876,6 +2876,7 @@ TEST_F(CudaBasicsTests1, execSummaryStats_3) {
// delete cuda stream
cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult);
}
*/
////////////////////////////////////////////////////////////////////////////
TEST_F(CudaBasicsTests1, execSummaryStatsScalar_1) {


@ -1054,6 +1054,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
ASSERT_TRUE(testData.equalsTo(result));
}
/*
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) {
NDArray input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
@ -1114,6 +1115,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) {
ASSERT_TRUE(expected.equalsTo(result));
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) {
NDArray input = NDArrayFactory::create<float>('c', {1, 3, 3, 1});
@ -1530,6 +1532,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) {
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
}
*/
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {


@ -826,6 +826,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
case FLOAT:
return ((FloatIndexer) indexer).get(i);
case UINT32:
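                //UIntIndexer widens the unsigned 32-bit value to long, so values above
                //Integer.MAX_VALUE are preserved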
return ((UIntIndexer) indexer).get(i);
case INT:
return ((IntIndexer) indexer).get(i);
case BFLOAT16:
@ -866,10 +867,11 @@ public abstract class BaseDataBuffer implements DataBuffer {
return (long) ((Bfloat16Indexer) indexer).get(i);
case HALF:
return (long) ((HalfIndexer) indexer).get( i);
case UINT64:
case UINT64: //Fall through
case LONG:
return ((LongIndexer) indexer).get(i);
case UINT32:
return (long) ((UIntIndexer) indexer).get(i);
case INT:
return (long) ((IntIndexer) indexer).get(i);
case UINT16:
@ -906,6 +908,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
case BOOL:
return (short) (((BooleanIndexer) indexer).get(i) ? 1 : 0);
case UINT32:
return (short) ((UIntIndexer)indexer).get(i);
case INT:
return (short) ((IntIndexer) indexer).get(i);
case UINT16:
@ -943,6 +946,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
case BOOL:
return ((BooleanIndexer) indexer).get(i) ? 1.f : 0.f;
case UINT32:
return (float) ((UIntIndexer)indexer).get(i);
case INT:
return (float) ((IntIndexer) indexer).get(i);
case UINT16:
@ -957,7 +961,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
return (float) ((UByteIndexer) indexer).get(i);
case BYTE:
return (float) ((ByteIndexer) indexer).get(i);
case UINT64:
case UINT64: //Fall through
case LONG:
return (float) ((LongIndexer) indexer).get(i);
case FLOAT:
@ -978,6 +982,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
case BOOL:
return ((BooleanIndexer) indexer).get(i) ? 1 : 0;
case UINT32:
return (int)((UIntIndexer) indexer).get(i);
case INT:
return ((IntIndexer) indexer).get(i);
case BFLOAT16:
@ -992,7 +997,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
return ((UByteIndexer) indexer).get(i);
case BYTE:
return ((ByteIndexer) indexer).get(i);
case UINT64:
case UINT64: //Fall through
case LONG:
return (int) ((LongIndexer) indexer).get(i);
case FLOAT:
@ -1058,6 +1063,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
((ShortIndexer) indexer).put(i, (short) element);
break;
case UINT32:
((UIntIndexer) indexer).put(i, (long)element);
break;
case INT:
((IntIndexer) indexer).put(i, (int) element);
break;
@ -1104,6 +1111,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
((ShortIndexer) indexer).put(i, (short) element);
break;
case UINT32:
((UIntIndexer) indexer).put(i, (long)element);
break;
case INT:
((IntIndexer) indexer).put(i, (int) element);
break;
@ -1150,10 +1159,12 @@ public abstract class BaseDataBuffer implements DataBuffer {
((ShortIndexer) indexer).put(i, (short) element);
break;
case UINT32:
((UIntIndexer) indexer).put(i, element);
break;
case INT:
((IntIndexer) indexer).put(i, element);
break;
case UINT64:
case UINT64: //Fall through
case LONG:
((LongIndexer) indexer).put(i, element);
break;
@ -1195,8 +1206,10 @@ public abstract class BaseDataBuffer implements DataBuffer {
case SHORT:
((ShortIndexer) indexer).put(i, element ? (short) 1 : (short) 0);
break;
case INT:
case UINT32:
((UIntIndexer) indexer).put(i, element ? 1 : 0);
break;
case INT:
((IntIndexer) indexer).put(i, element ? 1 : 0);
break;
case UINT64:
@ -1242,6 +1255,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
((ShortIndexer) indexer).put(i, (short) element);
break;
case UINT32:
((UIntIndexer) indexer).put(i, element);
break;
case INT:
((IntIndexer) indexer).put(i, (int) element);
break;


@ -324,6 +324,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
indexer = FloatIndexer.create((FloatPointer) pointer);
break;
case UINT32:
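                //Unsigned 32-bit values share the int storage; UIntIndexer provides the
                //unsigned interpretation of the same bits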
this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer();
indexer = UIntIndexer.create((IntPointer) pointer);
break;
case INT:
this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer();
indexer = IntIndexer.create((IntPointer) pointer);
@ -336,7 +339,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer();
indexer = HalfIndexer.create((ShortPointer) pointer);
break;
case UINT64:
case UINT64: //Fall through
case LONG:
this.pointer = new CudaPointer(hostPointer, length, 0).asLongPointer();
indexer = LongIndexer.create((LongPointer) pointer);
@ -501,6 +504,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
indexer = FloatIndexer.create((FloatPointer) pointer);
break;
case UINT32:
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer();
indexer = UIntIndexer.create((IntPointer) pointer);
break;
case INT:
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer();
indexer = IntIndexer.create((IntPointer) pointer);
@ -513,7 +519,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer();
indexer = HalfIndexer.create((ShortPointer) pointer);
break;
case UINT64:
case UINT64: //Fall through
case LONG:
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asLongPointer();
indexer = LongIndexer.create((LongPointer) pointer);


@ -121,6 +121,24 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();
setIndexer(ByteIndexer.create((BytePointer) pointer));
} else if(dataType() == DataType.FLOAT16){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
setIndexer(HalfIndexer.create((ShortPointer) pointer));
} else if(dataType() == DataType.BFLOAT16){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
} else if(dataType() == DataType.BOOL){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBoolPointer();
setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
} else if(dataType() == DataType.UINT16){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
setIndexer(UShortIndexer.create((ShortPointer) pointer));
} else if(dataType() == DataType.UINT32){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
setIndexer(UIntIndexer.create((IntPointer) pointer));
} else if (dataType() == DataType.UINT64) {
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
setIndexer(LongIndexer.create((LongPointer) pointer));
}
Nd4j.getDeallocatorService().pickObject(this);
@ -336,15 +354,13 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
} else if (dataType() == DataType.UINT32) {
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
// FIXME: we need unsigned indexer here
setIndexer(IntIndexer.create((IntPointer) pointer));
setIndexer(UIntIndexer.create((IntPointer) pointer));
if (initialize)
fillPointerWithZero();
} else if (dataType() == DataType.UINT64) {
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
// FIXME: we need unsigned indexer here
setIndexer(LongIndexer.create((LongPointer) pointer));
if (initialize)
@ -500,7 +516,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
// FIXME: need unsigned indexer here
pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length());
setIndexer(IntIndexer.create((IntPointer) pointer));
setIndexer(UIntIndexer.create((IntPointer) pointer));
} else if (dataType() == DataType.UINT64) {
attached = true;


@ -8395,6 +8395,25 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(e, z);
}
@Test
public void testCreateBufferFromByteBuffer(){
for(DataType dt : DataType.values()){
if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
continue;
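            //COMPRESSED, UTF8 and UNKNOWN have no fixed per-element width, so they
            //cannot be backed directly by a fixed-length ByteBuffer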
// System.out.println(dt);
int lengthBytes = 256;
int lengthElements = lengthBytes / dt.width();
ByteBuffer bb = ByteBuffer.allocateDirect(lengthBytes);
DataBuffer db = Nd4j.createBuffer(bb, dt, lengthElements, 0);
INDArray arr = Nd4j.create(db, new long[]{lengthElements});
            arr.toStringFull();     //Reads every element - should not throw for any fixed-width type
}
}
@Override
public char ordering() {
return 'c';