commit
9db86cec7a
|
@ -56,8 +56,10 @@ import static org.junit.Assert.*;
|
|||
*/
|
||||
public class RegressionTest050 extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType(){
|
||||
|
|
|
@ -64,6 +64,11 @@ public class RegressionTest060 extends BaseDL4JTest {
|
|||
return DataType.FLOAT;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Test
|
||||
public void regressionTestMLP1() throws Exception {
|
||||
|
||||
|
|
|
@ -64,6 +64,12 @@ public class RegressionTest071 extends BaseDL4JTest {
|
|||
public DataType getDataType(){
|
||||
return DataType.FLOAT;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Test
|
||||
public void regressionTestMLP1() throws Exception {
|
||||
|
||||
|
|
|
@ -64,6 +64,11 @@ public class RegressionTest080 extends BaseDL4JTest {
|
|||
return DataType.FLOAT;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Test
|
||||
public void regressionTestMLP1() throws Exception {
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ public class RegressionTest100a extends BaseDL4JTest {
|
|||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -52,7 +52,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
|
|||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -71,7 +71,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
|||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -58,7 +58,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
|||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -30,6 +30,11 @@ import static org.junit.Assert.assertTrue;
|
|||
*/
|
||||
public class TestDistributionDeserializer extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDistributionDeserializer() throws Exception {
|
||||
//Test current format:
|
||||
|
|
|
@ -46,6 +46,11 @@ public class TestDeepWalk extends BaseDL4JTest {
|
|||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 120_000L; //Increase timeout due to intermittently slow CI machines
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
public void testBasic() throws IOException {
|
||||
//Very basic test. Load graph, build tree, call fit, make sure it doesn't throw any exceptions
|
||||
|
|
|
@ -84,12 +84,6 @@
|
|||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.awaitility</groupId>
|
||||
<artifactId>awaitility</artifactId>
|
||||
<version>4.0.2</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
|||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
/**
|
||||
* Implementations of this interface should contain element-related learning algorithms. Like skip-gram, cbow or glove
|
||||
* Implementations of this interface should contain element-related learning algorithms. Like skip-gram or cbow
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
|
|
|
@ -1,427 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.embeddings.learning.impl.elements;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
|
||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||
import org.deeplearning4j.models.glove.AbstractCoOccurrences;
|
||||
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.legacy.AdaGrad;
|
||||
import org.nd4j.common.primitives.Counter;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
/**
|
||||
* GloVe LearningAlgorithm implementation for SequenceVectors
|
||||
*
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class GloVe<T extends SequenceElement> implements ElementsLearningAlgorithm<T> {
|
||||
|
||||
private VocabCache<T> vocabCache;
|
||||
private AbstractCoOccurrences<T> coOccurrences;
|
||||
private WeightLookupTable<T> lookupTable;
|
||||
private VectorsConfiguration configuration;
|
||||
|
||||
private AtomicBoolean isTerminate = new AtomicBoolean(false);
|
||||
|
||||
private INDArray syn0;
|
||||
|
||||
private double xMax;
|
||||
private boolean shuffle;
|
||||
private boolean symmetric;
|
||||
protected double alpha = 0.75d;
|
||||
protected double learningRate = 0.0d;
|
||||
protected int maxmemory = 0;
|
||||
protected int batchSize = 1000;
|
||||
|
||||
private AdaGrad weightAdaGrad;
|
||||
private AdaGrad biasAdaGrad;
|
||||
private INDArray bias;
|
||||
|
||||
private int workers = Runtime.getRuntime().availableProcessors();
|
||||
|
||||
private int vectorLength;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(GloVe.class);
|
||||
|
||||
@Override
|
||||
public String getCodeName() {
|
||||
return "GloVe";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() {
|
||||
log.info("GloVe finalizer...");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable,
|
||||
@NonNull VectorsConfiguration configuration) {
|
||||
this.vocabCache = vocabCache;
|
||||
this.lookupTable = lookupTable;
|
||||
this.configuration = configuration;
|
||||
|
||||
this.syn0 = ((InMemoryLookupTable<T>) lookupTable).getSyn0();
|
||||
|
||||
|
||||
this.vectorLength = configuration.getLayersSize();
|
||||
|
||||
if (this.learningRate == 0.0d)
|
||||
this.learningRate = configuration.getLearningRate();
|
||||
|
||||
|
||||
|
||||
weightAdaGrad = new AdaGrad(new long[] {this.vocabCache.numWords() + 1, vectorLength}, learningRate);
|
||||
bias = Nd4j.create(syn0.rows());
|
||||
|
||||
biasAdaGrad = new AdaGrad(bias.shape(), this.learningRate);
|
||||
|
||||
// maxmemory = Runtime.getRuntime().maxMemory() - (vocabCache.numWords() * vectorLength * 2 * 8);
|
||||
|
||||
log.info("GloVe params: {Max Memory: [" + maxmemory + "], Learning rate: [" + this.learningRate + "], Alpha: ["
|
||||
+ alpha + "], xMax: [" + xMax + "], Symmetric: [" + symmetric + "], Shuffle: [" + shuffle
|
||||
+ "]}");
|
||||
}
|
||||
|
||||
/**
|
||||
* pretrain is used to build CoOccurrence matrix for GloVe algorithm
|
||||
* @param iterator
|
||||
*/
|
||||
@Override
|
||||
public void pretrain(@NonNull SequenceIterator<T> iterator) {
|
||||
// CoOccurence table should be built here
|
||||
coOccurrences = new AbstractCoOccurrences.Builder<T>()
|
||||
// TODO: symmetric should be handled via VectorsConfiguration
|
||||
.symmetric(this.symmetric).windowSize(configuration.getWindow()).iterate(iterator)
|
||||
.workers(workers).vocabCache(vocabCache).maxMemory(maxmemory).build();
|
||||
|
||||
coOccurrences.fit();
|
||||
}
|
||||
|
||||
public double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate,
|
||||
BatchSequences<T> batchSequences) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
/**
|
||||
* Learns sequence using GloVe algorithm
|
||||
*
|
||||
* @param sequence
|
||||
* @param nextRandom
|
||||
* @param learningRate
|
||||
*/
|
||||
@Override
|
||||
public synchronized double learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom,
|
||||
double learningRate) {
|
||||
/*
|
||||
GloVe learning algorithm is implemented like a hack over settled ElementsLearningAlgorithm mechanics. It's called in SequenceVectors context, but actually only for the first call.
|
||||
All subsequent calls will met early termination condition, and will be successfully ignored. But since elements vectors will be updated within first call,
|
||||
this will allow compatibility with everything beyond this implementaton
|
||||
*/
|
||||
if (isTerminate.get())
|
||||
return 0;
|
||||
|
||||
final AtomicLong pairsCount = new AtomicLong(0);
|
||||
final Counter<Integer> errorCounter = new Counter<>();
|
||||
|
||||
//List<Pair<T, T>> coList = coOccurrences.coOccurrenceList();
|
||||
|
||||
for (int i = 0; i < configuration.getEpochs(); i++) {
|
||||
|
||||
// TODO: shuffle should be built in another way.
|
||||
//if (shuffle)
|
||||
//Collections.shuffle(coList);
|
||||
|
||||
Iterator<Pair<Pair<T, T>, Double>> pairs = coOccurrences.iterator();
|
||||
|
||||
List<GloveCalculationsThread> threads = new ArrayList<>();
|
||||
for (int x = 0; x < workers; x++) {
|
||||
threads.add(x, new GloveCalculationsThread(i, x, pairs, pairsCount, errorCounter));
|
||||
threads.get(x).start();
|
||||
}
|
||||
|
||||
|
||||
|
||||
for (int x = 0; x < workers; x++) {
|
||||
try {
|
||||
threads.get(x).join();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
log.info("Processed [" + pairsCount.get() + "] pairs, Error was [" + errorCounter.getCount(i) + "]");
|
||||
}
|
||||
|
||||
isTerminate.set(true);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Since GloVe is learning representations using elements CoOccurences, all training is done in GloVe class internally, so only first thread will execute learning process,
|
||||
* and the rest of parent threads will just exit learning process
|
||||
*
|
||||
* @return True, if training should stop, False otherwise.
|
||||
*/
|
||||
@Override
|
||||
public synchronized boolean isEarlyTerminationHit() {
|
||||
return isTerminate.get();
|
||||
}
|
||||
|
||||
private double iterateSample(T element1, T element2, double score) {
|
||||
//prediction: input + bias
|
||||
if (element1.getIndex() < 0 || element1.getIndex() >= syn0.rows())
|
||||
throw new IllegalArgumentException("Illegal index for word " + element1.getLabel());
|
||||
if (element2.getIndex() < 0 || element2.getIndex() >= syn0.rows())
|
||||
throw new IllegalArgumentException("Illegal index for word " + element2.getLabel());
|
||||
|
||||
INDArray w1Vector = syn0.slice(element1.getIndex());
|
||||
INDArray w2Vector = syn0.slice(element2.getIndex());
|
||||
|
||||
|
||||
//w1 * w2 + bias
|
||||
double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
|
||||
prediction += bias.getDouble(element1.getIndex()) + bias.getDouble(element2.getIndex()) - Math.log(score);
|
||||
|
||||
double fDiff = (score > xMax) ? prediction : Math.pow(score / xMax, alpha) * prediction; // Math.pow(Math.min(1.0,(score / maxCount)),xMax);
|
||||
|
||||
// double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
|
||||
|
||||
if (Double.isNaN(fDiff))
|
||||
fDiff = Nd4j.EPS_THRESHOLD;
|
||||
//amount of change
|
||||
double gradient = fDiff * learningRate;
|
||||
|
||||
//note the update step here: the gradient is
|
||||
//the gradient of the OPPOSITE word
|
||||
//for adagrad we will use the index of the word passed in
|
||||
//for the gradient calculation we will use the context vector
|
||||
update(element1, w1Vector, w2Vector, gradient);
|
||||
update(element2, w2Vector, w1Vector, gradient);
|
||||
return 0.5 * fDiff * prediction;
|
||||
}
|
||||
|
||||
private void update(T element1, INDArray wordVector, INDArray contextVector, double gradient) {
|
||||
//gradient for word vectors
|
||||
INDArray grad1 = contextVector.mul(gradient);
|
||||
INDArray update = weightAdaGrad.getGradient(grad1, element1.getIndex(), syn0.shape());
|
||||
|
||||
//update vector
|
||||
wordVector.subi(update);
|
||||
|
||||
double w1Bias = bias.getDouble(element1.getIndex());
|
||||
double biasGradient = biasAdaGrad.getGradient(gradient, element1.getIndex(), bias.shape());
|
||||
double update2 = w1Bias - biasGradient;
|
||||
bias.putScalar(element1.getIndex(), update2);
|
||||
}
|
||||
|
||||
private class GloveCalculationsThread extends Thread implements Runnable {
|
||||
private final int threadId;
|
||||
private final int epochId;
|
||||
// private final AbstractCoOccurrences<T> coOccurrences;
|
||||
private final Iterator<Pair<Pair<T, T>, Double>> coList;
|
||||
|
||||
private final AtomicLong pairsCounter;
|
||||
private final Counter<Integer> errorCounter;
|
||||
|
||||
public GloveCalculationsThread(int epochId, int threadId, @NonNull Iterator<Pair<Pair<T, T>, Double>> pairs,
|
||||
@NonNull AtomicLong pairsCounter, @NonNull Counter<Integer> errorCounter) {
|
||||
this.epochId = epochId;
|
||||
this.threadId = threadId;
|
||||
// this.coOccurrences = coOccurrences;
|
||||
|
||||
this.pairsCounter = pairsCounter;
|
||||
this.errorCounter = errorCounter;
|
||||
|
||||
coList = pairs;
|
||||
|
||||
this.setName("GloVe ELA t." + this.threadId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
// int startPosition = threadId * (coList.size() / workers);
|
||||
// int stopPosition = (threadId + 1) * (coList.size() / workers);
|
||||
// log.info("Total size: [" + coList.size() + "], thread start: [" + startPosition + "], thread stop: [" + stopPosition + "]");
|
||||
while (coList.hasNext()) {
|
||||
|
||||
// now we fetch pairs into batch
|
||||
List<Pair<Pair<T, T>, Double>> pairs = new ArrayList<>();
|
||||
int cnt = 0;
|
||||
while (coList.hasNext() && cnt < batchSize) {
|
||||
pairs.add(coList.next());
|
||||
cnt++;
|
||||
}
|
||||
|
||||
if (shuffle)
|
||||
Collections.shuffle(pairs);
|
||||
|
||||
Iterator<Pair<Pair<T, T>, Double>> iterator = pairs.iterator();
|
||||
|
||||
while (iterator.hasNext()) {
|
||||
// now for each pair do appropriate training
|
||||
Pair<Pair<T, T>, Double> pairDoublePair = iterator.next();
|
||||
|
||||
// That's probably ugly and probably should be improved somehow
|
||||
|
||||
T element1 = pairDoublePair.getFirst().getFirst();
|
||||
T element2 = pairDoublePair.getFirst().getSecond();
|
||||
double weight = pairDoublePair.getSecond(); //coOccurrences.getCoOccurrenceCount(element1, element2);
|
||||
if (weight <= 0) {
|
||||
// log.warn("Skipping pair ("+ element1.getLabel()+", " + element2.getLabel()+")");
|
||||
pairsCounter.incrementAndGet();
|
||||
continue;
|
||||
}
|
||||
|
||||
errorCounter.incrementCount(epochId, iterateSample(element1, element2, weight));
|
||||
if (pairsCounter.incrementAndGet() % 1000000 == 0) {
|
||||
log.info("Processed [" + pairsCounter.get() + "] word pairs so far...");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static class Builder<T extends SequenceElement> {
|
||||
|
||||
protected double xMax = 100.0d;
|
||||
protected double alpha = 0.75d;
|
||||
protected double learningRate = 0.0d;
|
||||
|
||||
protected boolean shuffle = false;
|
||||
protected boolean symmetric = false;
|
||||
protected int maxmemory = 0;
|
||||
|
||||
protected int batchSize = 1000;
|
||||
|
||||
public Builder() {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* This parameter specifies batch size for each thread. Also, if shuffle == TRUE, this batch will be shuffled before processing. Default value: 1000;
|
||||
*
|
||||
* @param batchSize
|
||||
* @return
|
||||
*/
|
||||
public Builder<T> batchSize(int batchSize) {
|
||||
this.batchSize = batchSize;
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Initial learning rate; default 0.05
|
||||
*
|
||||
* @param eta
|
||||
* @return
|
||||
*/
|
||||
public Builder<T> learningRate(double eta) {
|
||||
this.learningRate = eta;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameter in exponent of weighting function; default 0.75
|
||||
*
|
||||
* @param alpha
|
||||
* @return
|
||||
*/
|
||||
public Builder<T> alpha(double alpha) {
|
||||
this.alpha = alpha;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify maximum memory available for CoOccurrence map builder.
|
||||
*
|
||||
* Please note: this option can be considered a debugging method. In most cases setting proper -Xmx argument set to JVM is enough to limit this algorithm.
|
||||
* Please note: this option won't override -Xmx JVM value.
|
||||
*
|
||||
* @param gbytes memory limit, in gigabytes
|
||||
* @return
|
||||
*/
|
||||
public Builder<T> maxMemory(int gbytes) {
|
||||
this.maxmemory = gbytes;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameter specifying cutoff in weighting function; default 100.0
|
||||
*
|
||||
* @param xMax
|
||||
* @return
|
||||
*/
|
||||
public Builder<T> xMax(double xMax) {
|
||||
this.xMax = xMax;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameter specifying, if cooccurrences list should be shuffled between training epochs
|
||||
*
|
||||
* @param reallyShuffle
|
||||
* @return
|
||||
*/
|
||||
public Builder<T> shuffle(boolean reallyShuffle) {
|
||||
this.shuffle = reallyShuffle;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameters specifying, if cooccurrences list should be build into both directions from any current word.
|
||||
*
|
||||
* @param reallySymmetric
|
||||
* @return
|
||||
*/
|
||||
public Builder<T> symmetric(boolean reallySymmetric) {
|
||||
this.symmetric = reallySymmetric;
|
||||
return this;
|
||||
}
|
||||
|
||||
public GloVe<T> build() {
|
||||
GloVe<T> ret = new GloVe<>();
|
||||
ret.symmetric = this.symmetric;
|
||||
ret.shuffle = this.shuffle;
|
||||
ret.xMax = this.xMax;
|
||||
ret.alpha = this.alpha;
|
||||
ret.learningRate = this.learningRate;
|
||||
ret.maxmemory = this.maxmemory;
|
||||
ret.batchSize = this.batchSize;
|
||||
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -64,7 +64,6 @@ import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
|||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
||||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
|
||||
import org.deeplearning4j.models.fasttext.FastText;
|
||||
import org.deeplearning4j.models.glove.Glove;
|
||||
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
|
||||
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
||||
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceElementFactory;
|
||||
|
@ -142,10 +141,6 @@ import lombok.val;
|
|||
* {@link #readParagraphVectors(String)}
|
||||
* {@link #readParagraphVectors(InputStream)}
|
||||
*
|
||||
* <li>Serializers for GloVe:</li>
|
||||
* {@link #writeWordVectors(Glove, File)}
|
||||
* {@link #writeWordVectors(Glove, String)}
|
||||
* {@link #writeWordVectors(Glove, OutputStream)}
|
||||
*
|
||||
* <li>Adapters</li>
|
||||
* {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
|
||||
|
@ -154,7 +149,6 @@ import lombok.val;
|
|||
* {@link #loadTxt(InputStream)}
|
||||
*
|
||||
* <li>Serializers to tSNE format</li>
|
||||
* {@link #writeTsneFormat(Glove, INDArray, File)}
|
||||
* {@link #writeTsneFormat(Word2Vec, INDArray, File)}
|
||||
*
|
||||
* <li>FastText serializer:</li>
|
||||
|
@ -974,7 +968,7 @@ public class WordVectorSerializer {
|
|||
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes,
|
||||
@NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
|
||||
// first we load syn0
|
||||
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors));
|
||||
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors)); //Note stream is closed in loadTxt
|
||||
InMemoryLookupTable lookupTable = pair.getFirst();
|
||||
lookupTable.setNegative(configuration.getNegative());
|
||||
if (configuration.getNegative() > 0)
|
||||
|
@ -1161,48 +1155,6 @@ public class WordVectorSerializer {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This method saves GloVe model to the given output stream.
|
||||
*
|
||||
* @param vectors GloVe model to be saved
|
||||
* @param file path where model should be saved to
|
||||
*/
|
||||
public static void writeWordVectors(@NonNull Glove vectors, @NonNull File file) {
|
||||
try (BufferedOutputStream fos = new BufferedOutputStream(new FileOutputStream(file))) {
|
||||
writeWordVectors(vectors, fos);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This method saves GloVe model to the given output stream.
|
||||
*
|
||||
* @param vectors GloVe model to be saved
|
||||
* @param path path where model should be saved to
|
||||
*/
|
||||
public static void writeWordVectors(@NonNull Glove vectors, @NonNull String path) {
|
||||
try (BufferedOutputStream fos = new BufferedOutputStream(new FileOutputStream(path))) {
|
||||
writeWordVectors(vectors, fos);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This method saves GloVe model to the given OutputStream
|
||||
*
|
||||
* @param vectors GloVe model to be saved
|
||||
* @param stream OutputStream where model should be saved to
|
||||
*/
|
||||
public static void writeWordVectors(@NonNull Glove vectors, @NonNull OutputStream stream) {
|
||||
try {
|
||||
writeWordVectors(vectors.lookupTable(), stream);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This method saves paragraph vectors to the given output stream.
|
||||
*
|
||||
|
@ -1655,7 +1607,7 @@ public class WordVectorSerializer {
|
|||
*/
|
||||
@Deprecated
|
||||
public static WordVectors loadTxtVectors(File vectorsFile) throws IOException {
|
||||
FileInputStream fileInputStream = new FileInputStream(vectorsFile);
|
||||
FileInputStream fileInputStream = new FileInputStream(vectorsFile); //Note stream is closed in loadTxt
|
||||
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(fileInputStream);
|
||||
return fromPair(pair);
|
||||
}
|
||||
|
@ -1877,43 +1829,6 @@ public class WordVectorSerializer {
|
|||
return fromPair(Pair.makePair((InMemoryLookupTable) lookupTable, (VocabCache) cache));
|
||||
}
|
||||
|
||||
/**
|
||||
* Write the tsne format
|
||||
*
|
||||
* @param vec the word vectors to use for labeling
|
||||
* @param tsne the tsne array to write
|
||||
* @param csv the file to use
|
||||
* @throws Exception
|
||||
*/
|
||||
public static void writeTsneFormat(Glove vec, INDArray tsne, File csv) throws Exception {
|
||||
try (BufferedWriter write = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(csv), StandardCharsets.UTF_8))) {
|
||||
int words = 0;
|
||||
InMemoryLookupCache l = (InMemoryLookupCache) vec.vocab();
|
||||
for (String word : vec.vocab().words()) {
|
||||
if (word == null) {
|
||||
continue;
|
||||
}
|
||||
StringBuilder sb = new StringBuilder();
|
||||
INDArray wordVector = tsne.getRow(l.wordFor(word).getIndex());
|
||||
for (int j = 0; j < wordVector.length(); j++) {
|
||||
sb.append(wordVector.getDouble(j));
|
||||
if (j < wordVector.length() - 1) {
|
||||
sb.append(",");
|
||||
}
|
||||
}
|
||||
sb.append(",");
|
||||
sb.append(word.replaceAll(" ", WHITESPACE_REPLACEMENT));
|
||||
sb.append(" ");
|
||||
|
||||
sb.append("\n");
|
||||
write.write(sb.toString());
|
||||
|
||||
}
|
||||
|
||||
log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Write the tsne format
|
||||
*
|
||||
|
|
|
@ -1,652 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.deeplearning4j.models.glove.count.*;
|
||||
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
|
||||
import org.deeplearning4j.models.sequencevectors.iterators.FilteredSequenceIterator;
|
||||
import org.deeplearning4j.models.sequencevectors.iterators.SynchronizedSequenceIterator;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.SynchronizedSentenceIterator;
|
||||
import org.deeplearning4j.common.util.DL4JFileUtils;
|
||||
import org.nd4j.common.util.ThreadUtils;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||
|
||||
/**
|
||||
* This class implements building cooccurrence map for abstract training corpus.
|
||||
* However it's performance rather low, due to exsessive IO that happens in ShadowCopyThread
|
||||
*
|
||||
* PLEASE NOTE: Current implementation involves massive IO, and it should be rewritter as soon as ND4j gets sparse arrays support
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class AbstractCoOccurrences<T extends SequenceElement> implements Serializable {
|
||||
|
||||
protected boolean symmetric;
|
||||
protected int windowSize;
|
||||
protected VocabCache<T> vocabCache;
|
||||
protected SequenceIterator<T> sequenceIterator;
|
||||
|
||||
// please note, we need enough room for ShadowCopy thread, that's why -1 there
|
||||
protected int workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);
|
||||
|
||||
// target file, where text with cooccurrencies should be saved
|
||||
protected File targetFile;
|
||||
|
||||
protected ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
|
||||
|
||||
protected long memory_threshold = 0;
|
||||
|
||||
private ShadowCopyThread shadowThread;
|
||||
|
||||
// private Counter<Integer> sentenceOccurrences = Util.parallelCounter();
|
||||
//private CounterMap<T, T> coOccurrenceCounts = Util.parallelCounterMap();
|
||||
private volatile CountMap<T> coOccurrenceCounts = new CountMap<>();
|
||||
//private Counter<Integer> occurrenceAllocations = Util.parallelCounter();
|
||||
//private List<Pair<T, T>> coOccurrences;
|
||||
private AtomicLong processedSequences = new AtomicLong(0);
|
||||
|
||||
|
||||
protected static final Logger logger = LoggerFactory.getLogger(AbstractCoOccurrences.class);
|
||||
|
||||
// this method should be private, to avoid non-configured instantiation
|
||||
private AbstractCoOccurrences() {}
|
||||
|
||||
/**
|
||||
* This method returns cooccurrence distance weights for two SequenceElements
|
||||
*
|
||||
* @param element1
|
||||
* @param element2
|
||||
* @return distance weight
|
||||
*/
|
||||
public double getCoOccurrenceCount(@NonNull T element1, @NonNull T element2) {
|
||||
return coOccurrenceCounts.getCount(element1, element2);
|
||||
}
|
||||
|
||||
/**
|
||||
* This method returns estimated memory footrpint, based on current CountMap content
|
||||
* @return
|
||||
*/
|
||||
protected long getMemoryFootprint() {
|
||||
// TODO: implement this method. It should return approx. memory used by appropriate CountMap
|
||||
try {
|
||||
lock.readLock().lock();
|
||||
return ((long) coOccurrenceCounts.size()) * 24L * 5L;
|
||||
} finally {
|
||||
lock.readLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This memory returns memory threshold, defined as 1/2 of memory allowed for allocation
|
||||
* @return
|
||||
*/
|
||||
protected long getMemoryThreshold() {
|
||||
return memory_threshold / 2L;
|
||||
}
|
||||
|
||||
public void fit() {
|
||||
shadowThread = new ShadowCopyThread();
|
||||
shadowThread.start();
|
||||
|
||||
// we should reset iterator before counting cooccurrences
|
||||
sequenceIterator.reset();
|
||||
|
||||
List<CoOccurrencesCalculatorThread> threads = new ArrayList<>();
|
||||
for (int x = 0; x < workers; x++) {
|
||||
threads.add(x, new CoOccurrencesCalculatorThread(x, new FilteredSequenceIterator<>(
|
||||
new SynchronizedSequenceIterator<>(sequenceIterator), vocabCache), processedSequences));
|
||||
threads.get(x).start();
|
||||
}
|
||||
|
||||
for (int x = 0; x < workers; x++) {
|
||||
try {
|
||||
threads.get(x).join();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
shadowThread.finish();
|
||||
logger.info("CoOccurrences map was built.");
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* This method returns iterator with elements pairs and their weights. Resulting iterator is safe to use in multi-threaded environment.
|
||||
*
|
||||
* Developer's note: thread safety on received iterator is delegated to PrefetchedSentenceIterator
|
||||
* @return
|
||||
*/
|
||||
public Iterator<Pair<Pair<T, T>, Double>> iterator() {
|
||||
final SentenceIterator iterator;
|
||||
|
||||
try {
|
||||
iterator = new SynchronizedSentenceIterator(
|
||||
new PrefetchingSentenceIterator.Builder(new BasicLineIterator(targetFile))
|
||||
.setFetchSize(500000).build());
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("Target file was not found on last stage!");
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return new Iterator<Pair<Pair<T, T>, Double>>() {
|
||||
/*
|
||||
iterator should be built on top of current text file with all pairs
|
||||
*/
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return iterator.hasNext();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pair<Pair<T, T>, Double> next() {
|
||||
String line = iterator.nextSentence();
|
||||
String[] strings = line.split(" ");
|
||||
|
||||
T element1 = vocabCache.elementAtIndex(Integer.valueOf(strings[0]));
|
||||
T element2 = vocabCache.elementAtIndex(Integer.valueOf(strings[1]));
|
||||
Double weight = Double.valueOf(strings[2]);
|
||||
|
||||
return new Pair<>(new Pair<>(element1, element2), weight);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void remove() {
|
||||
throw new UnsupportedOperationException("remove() method can't be supported on read-only interface");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public static class Builder<T extends SequenceElement> {
|
||||
|
||||
protected boolean symmetric;
|
||||
protected int windowSize = 5;
|
||||
protected VocabCache<T> vocabCache;
|
||||
protected SequenceIterator<T> sequenceIterator;
|
||||
protected int workers = Runtime.getRuntime().availableProcessors();
|
||||
protected File target;
|
||||
protected long maxmemory = Runtime.getRuntime().maxMemory();
|
||||
|
||||
public Builder() {
|
||||
|
||||
}
|
||||
|
||||
public Builder<T> symmetric(boolean reallySymmetric) {
|
||||
this.symmetric = reallySymmetric;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder<T> windowSize(int windowSize) {
|
||||
this.windowSize = windowSize;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder<T> vocabCache(@NonNull VocabCache<T> cache) {
|
||||
this.vocabCache = cache;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder<T> iterate(@NonNull SequenceIterator<T> iterator) {
|
||||
this.sequenceIterator = new SynchronizedSequenceIterator<>(iterator);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder<T> workers(int numWorkers) {
|
||||
this.workers = numWorkers;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify maximum memory available for CoOccurrence map builder.
|
||||
*
|
||||
* Please note: this option can be considered a debugging method. In most cases setting proper -Xmx argument set to JVM is enough to limit this algorithm.
|
||||
* Please note: this option won't override -Xmx JVM value.
|
||||
*
|
||||
* @param gbytes memory available, in GigaBytes
|
||||
* @return
|
||||
*/
|
||||
public Builder<T> maxMemory(int gbytes) {
|
||||
if (gbytes > 0) {
|
||||
this.maxmemory = Math.max(gbytes - 1, 1) * 1024 * 1024 * 1024L;
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Path to save cooccurrence map after construction.
|
||||
* If targetFile is not specified, temporary file will be used.
|
||||
*
|
||||
* @param path
|
||||
* @return
|
||||
*/
|
||||
public Builder<T> targetFile(@NonNull String path) {
|
||||
this.targetFile(new File(path));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Path to save cooccurrence map after construction.
|
||||
* If targetFile is not specified, temporary file will be used.
|
||||
*
|
||||
* @param file
|
||||
* @return
|
||||
*/
|
||||
public Builder<T> targetFile(@NonNull File file) {
|
||||
this.target = file;
|
||||
return this;
|
||||
}
|
||||
|
||||
public AbstractCoOccurrences<T> build() {
|
||||
AbstractCoOccurrences<T> ret = new AbstractCoOccurrences<>();
|
||||
ret.sequenceIterator = this.sequenceIterator;
|
||||
ret.windowSize = this.windowSize;
|
||||
ret.vocabCache = this.vocabCache;
|
||||
ret.symmetric = this.symmetric;
|
||||
ret.workers = this.workers;
|
||||
|
||||
if (this.maxmemory < 1) {
|
||||
this.maxmemory = Runtime.getRuntime().maxMemory();
|
||||
}
|
||||
ret.memory_threshold = this.maxmemory;
|
||||
|
||||
|
||||
logger.info("Actual memory limit: [" + this.maxmemory + "]");
|
||||
|
||||
// use temp file, if no target file was specified
|
||||
try {
|
||||
if (this.target == null) {
|
||||
this.target = DL4JFileUtils.createTempFile("cooccurrence", "map");
|
||||
}
|
||||
this.target.deleteOnExit();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
ret.targetFile = this.target;
|
||||
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
private class CoOccurrencesCalculatorThread extends Thread implements Runnable {
|
||||
|
||||
private final SequenceIterator<T> iterator;
|
||||
private final AtomicLong sequenceCounter;
|
||||
private int threadId;
|
||||
|
||||
public CoOccurrencesCalculatorThread(int threadId, @NonNull SequenceIterator<T> iterator,
|
||||
@NonNull AtomicLong sequenceCounter) {
|
||||
this.iterator = iterator;
|
||||
this.sequenceCounter = sequenceCounter;
|
||||
this.threadId = threadId;
|
||||
|
||||
this.setName("CoOccurrencesCalculatorThread " + threadId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
while (iterator.hasMoreSequences()) {
|
||||
Sequence<T> sequence = iterator.nextSequence();
|
||||
|
||||
List<String> tokens = new ArrayList<>(sequence.asLabels());
|
||||
// logger.info("Tokens size: " + tokens.size());
|
||||
for (int x = 0; x < sequence.getElements().size(); x++) {
|
||||
int wordIdx = vocabCache.indexOf(tokens.get(x));
|
||||
if (wordIdx < 0) {
|
||||
continue;
|
||||
}
|
||||
String w1 = vocabCache.wordFor(tokens.get(x)).getLabel();
|
||||
|
||||
// THIS iS SAFE TO REMOVE, NO CHANCE WE'll HAVE UNK WORD INSIDE SEQUENCE
|
||||
/*if(w1.equals(Glove.UNK))
|
||||
continue;
|
||||
*/
|
||||
|
||||
int windowStop = Math.min(x + windowSize + 1, tokens.size());
|
||||
for (int j = x; j < windowStop; j++) {
|
||||
int otherWord = vocabCache.indexOf(tokens.get(j));
|
||||
if (otherWord < 0) {
|
||||
continue;
|
||||
}
|
||||
String w2 = vocabCache.wordFor(tokens.get(j)).getLabel();
|
||||
|
||||
if (w2.equals(Glove.DEFAULT_UNK) || otherWord == wordIdx) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
T tokenX = vocabCache.wordFor(tokens.get(x));
|
||||
T tokenJ = vocabCache.wordFor(tokens.get(j));
|
||||
double nWeight = 1.0 / (j - x + Nd4j.EPS_THRESHOLD);
|
||||
|
||||
while (getMemoryFootprint() >= getMemoryThreshold()) {
|
||||
shadowThread.invoke();
|
||||
/*lock.readLock().lock();
|
||||
int size = coOccurrenceCounts.size();
|
||||
lock.readLock().unlock();
|
||||
*/
|
||||
if (threadId == 0) {
|
||||
logger.debug("Memory consuimption > threshold: {footrpint: [" + getMemoryFootprint()
|
||||
+ "], threshold: [" + getMemoryThreshold() + "] }");
|
||||
}
|
||||
ThreadUtils.uncheckedSleep(10000);
|
||||
}
|
||||
/*
|
||||
if (getMemoryFootprint() == 0) {
|
||||
logger.info("Zero size!");
|
||||
}
|
||||
*/
|
||||
|
||||
try {
|
||||
lock.readLock().lock();
|
||||
if (wordIdx < otherWord) {
|
||||
coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight);
|
||||
if (symmetric) {
|
||||
coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight);
|
||||
}
|
||||
} else {
|
||||
coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight);
|
||||
|
||||
if (symmetric) {
|
||||
coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
lock.readLock().unlock();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sequenceCounter.incrementAndGet();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This class is designed to provide shadow copy functionality for CoOccurence maps, since with proper corpus size you can't fit such a map into memory
|
||||
*
|
||||
*/
|
||||
private class ShadowCopyThread extends Thread implements Runnable {
|
||||
|
||||
private AtomicBoolean isFinished = new AtomicBoolean(false);
|
||||
private AtomicBoolean isTerminate = new AtomicBoolean(false);
|
||||
private AtomicBoolean isInvoked = new AtomicBoolean(false);
|
||||
private AtomicBoolean shouldInvoke = new AtomicBoolean(false);
|
||||
|
||||
// file that contains resuts from previous runs
|
||||
private File[] tempFiles;
|
||||
private RoundCount counter;
|
||||
|
||||
public ShadowCopyThread() {
|
||||
try {
|
||||
|
||||
counter = new RoundCount(1);
|
||||
tempFiles = new File[2];
|
||||
|
||||
tempFiles[0] = DL4JFileUtils.createTempFile("aco", "tmp");
|
||||
tempFiles[1] = DL4JFileUtils.createTempFile("aco", "tmp");
|
||||
|
||||
tempFiles[0].deleteOnExit();
|
||||
tempFiles[1].deleteOnExit();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
this.setName("ACO ShadowCopy thread");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
/*
|
||||
Basic idea is pretty simple: run quetly, untill memory gets filled up to some high volume.
|
||||
As soon as this happens - execute shadow copy.
|
||||
*/
|
||||
while (!isFinished.get() && !isTerminate.get()) {
|
||||
// check used memory. if memory use below threshold - sleep for a while. if above threshold - invoke copier
|
||||
|
||||
if (getMemoryFootprint() > getMemoryThreshold() || (shouldInvoke.get() && !isInvoked.get())) {
|
||||
// we'll just invoke copier, nothing else
|
||||
shouldInvoke.compareAndSet(true, false);
|
||||
invokeBlocking();
|
||||
} else {
|
||||
/*
|
||||
commented and left here for future debugging purposes, if needed
|
||||
|
||||
//lock.readLock().lock();
|
||||
//int size = coOccurrenceCounts.size();
|
||||
//lock.readLock().unlock();
|
||||
//logger.info("Current memory situation: {size: [" +size+ "], footprint: [" + getMemoryFootprint()+"], threshold: ["+ getMemoryThreshold() +"]}");
|
||||
*/
|
||||
ThreadUtils.uncheckedSleep(1000);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This methods advises shadow copy process to start
|
||||
*/
|
||||
public void invoke() {
|
||||
shouldInvoke.compareAndSet(false, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* This methods dumps cooccurrence map into save file.
|
||||
* Please note: this method is synchronized and will block, until complete
|
||||
*/
|
||||
public synchronized void invokeBlocking() {
|
||||
if (getMemoryFootprint() < getMemoryThreshold() && !isFinished.get()) {
|
||||
return;
|
||||
}
|
||||
|
||||
int numberOfLinesSaved = 0;
|
||||
|
||||
isInvoked.set(true);
|
||||
|
||||
logger.debug("Memory purge started.");
|
||||
|
||||
/*
|
||||
Basic plan:
|
||||
1. Open temp file
|
||||
2. Read that file line by line
|
||||
3. For each read line do synchronization in memory > new file direction
|
||||
*/
|
||||
|
||||
counter.tick();
|
||||
|
||||
CountMap<T> localMap;
|
||||
try {
|
||||
// in any given moment there's going to be only 1 WriteLock, due to invokeBlocking() being synchronized call
|
||||
lock.writeLock().lock();
|
||||
|
||||
|
||||
|
||||
// obtain local copy of CountMap
|
||||
localMap = coOccurrenceCounts;
|
||||
|
||||
// set new CountMap, and release write lock
|
||||
coOccurrenceCounts = new CountMap<>();
|
||||
} finally {
|
||||
lock.writeLock().unlock();
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
File file = null;
|
||||
if (!isFinished.get()) {
|
||||
file = tempFiles[counter.previous()];
|
||||
} else
|
||||
file = targetFile;
|
||||
|
||||
|
||||
// PrintWriter pw = new PrintWriter(file);
|
||||
|
||||
int linesRead = 0;
|
||||
|
||||
logger.debug("Saving to: [" + counter.get() + "], Reading from: [" + counter.previous() + "]");
|
||||
CoOccurenceReader<T> reader =
|
||||
new BinaryCoOccurrenceReader<>(tempFiles[counter.previous()], vocabCache, localMap);
|
||||
CoOccurrenceWriter<T> writer = (isFinished.get()) ? new ASCIICoOccurrenceWriter<T>(targetFile)
|
||||
: new BinaryCoOccurrenceWriter<T>(tempFiles[counter.get()]);
|
||||
while (reader.hasMoreObjects()) {
|
||||
CoOccurrenceWeight<T> line = reader.nextObject();
|
||||
|
||||
if (line != null) {
|
||||
writer.writeObject(line);
|
||||
numberOfLinesSaved++;
|
||||
linesRead++;
|
||||
}
|
||||
}
|
||||
reader.finish();
|
||||
|
||||
logger.debug("Lines read: [" + linesRead + "]");
|
||||
|
||||
//now, we can dump the rest of elements, which were not presented in existing dump
|
||||
Iterator<Pair<T, T>> iterator = localMap.getPairIterator();
|
||||
while (iterator.hasNext()) {
|
||||
Pair<T, T> pair = iterator.next();
|
||||
double mWeight = localMap.getCount(pair);
|
||||
CoOccurrenceWeight<T> object = new CoOccurrenceWeight<>();
|
||||
object.setElement1(pair.getFirst());
|
||||
object.setElement2(pair.getSecond());
|
||||
object.setWeight(mWeight);
|
||||
|
||||
writer.writeObject(object);
|
||||
|
||||
numberOfLinesSaved++;
|
||||
// if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
|
||||
}
|
||||
|
||||
writer.finish();
|
||||
|
||||
/*
|
||||
SentenceIterator sIterator = new PrefetchingSentenceIterator.Builder(new BasicLineIterator(tempFiles[counter.get()]))
|
||||
.setFetchSize(500000)
|
||||
.build();
|
||||
|
||||
|
||||
int linesRead = 0;
|
||||
while (sIterator.hasNext()) {
|
||||
//List<Writable> list = new ArrayList<>(reader.next());
|
||||
String sentence = sIterator.nextSentence();
|
||||
if (sentence == null || sentence.isEmpty()) continue;
|
||||
String[] strings = sentence.split(" ");
|
||||
|
||||
|
||||
// first two elements are integers - vocab indexes
|
||||
//T element1 = vocabCache.wordFor(vocabCache.wordAtIndex(list.get(0).toInt()));
|
||||
//T element2 = vocabCache.wordFor(vocabCache.wordAtIndex(list.get(1).toInt()));
|
||||
T element1 = vocabCache.elementAtIndex(Integer.valueOf(strings[0]));
|
||||
T element2 = vocabCache.elementAtIndex(Integer.valueOf(strings[1]));
|
||||
|
||||
// getting third element, previously stored weight
|
||||
double sWeight = Double.valueOf(strings[2]); // list.get(2).toDouble();
|
||||
|
||||
// now, since we have both elements ready, we can check this pair against inmemory map
|
||||
double mWeight = localMap.getCount(element1, element2);
|
||||
if (mWeight <= 0) {
|
||||
// this means we have no such pair in memory, so we'll do nothing to sWeight
|
||||
} else {
|
||||
// since we have new weight value in memory, we should update sWeight value before moving it off memory
|
||||
sWeight += mWeight;
|
||||
|
||||
// original pair can be safely removed from CountMap
|
||||
localMap.removePair(element1,element2);
|
||||
}
|
||||
|
||||
StringBuilder builder = new StringBuilder().append(element1.getIndex()).append(" ").append(element2.getIndex()).append(" ").append(sWeight);
|
||||
pw.println(builder.toString());
|
||||
numberOfLinesSaved++;
|
||||
linesRead++;
|
||||
|
||||
// if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
|
||||
// if (linesRead % 100000 == 0) logger.info("Lines read: [" + linesRead +"]");
|
||||
}
|
||||
*/
|
||||
/*
|
||||
logger.info("Lines read: [" + linesRead + "]");
|
||||
|
||||
//now, we can dump the rest of elements, which were not presented in existing dump
|
||||
Iterator<Pair<T, T>> iterator = localMap.getPairIterator();
|
||||
while (iterator.hasNext()) {
|
||||
Pair<T, T> pair = iterator.next();
|
||||
double mWeight = localMap.getCount(pair);
|
||||
|
||||
StringBuilder builder = new StringBuilder().append(pair.getFirst().getIndex()).append(" ").append(pair.getFirst().getIndex()).append(" ").append(mWeight);
|
||||
pw.println(builder.toString());
|
||||
numberOfLinesSaved++;
|
||||
|
||||
// if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]");
|
||||
}
|
||||
|
||||
pw.flush();
|
||||
pw.close();
|
||||
|
||||
*/
|
||||
|
||||
// just a hint for gc
|
||||
localMap = null;
|
||||
//sIterator.finish();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
logger.info("Number of word pairs saved so far: [" + numberOfLinesSaved + "]");
|
||||
isInvoked.set(false);
|
||||
}
|
||||
|
||||
/**
|
||||
* This method provides soft finish ability for shadow copy process.
|
||||
* Please note: it's blocking call, since it requires for final merge.
|
||||
*/
|
||||
public void finish() {
|
||||
if (this.isFinished.get()) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.isFinished.set(true);
|
||||
invokeBlocking();
|
||||
}
|
||||
|
||||
/**
|
||||
* This method provides hard fiinish ability for shadow copy process
|
||||
*/
|
||||
public void terminate() {
|
||||
this.isTerminate.set(true);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,444 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||
import org.deeplearning4j.models.embeddings.learning.impl.elements.GloVe;
|
||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
|
||||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
||||
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
||||
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
|
||||
import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener;
|
||||
import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
|
||||
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
|
||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.deeplearning4j.text.documentiterator.DocumentIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.StreamLineIterator;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* GlobalVectors standalone implementation for DL4j.
|
||||
* Based on original Stanford GloVe <a href="http://www-nlp.stanford.edu/pubs/glove.pdf">http://www-nlp.stanford.edu/pubs/glove.pdf</a>
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class Glove extends SequenceVectors<VocabWord> {
|
||||
|
||||
protected Glove() {
|
||||
|
||||
}
|
||||
|
||||
public static class Builder extends SequenceVectors.Builder<VocabWord> {
|
||||
private double xMax;
|
||||
private boolean shuffle;
|
||||
private boolean symmetric;
|
||||
protected double alpha = 0.75d;
|
||||
private int maxmemory = (int) (Runtime.getRuntime().totalMemory() / 1024 / 1024 / 1024);
|
||||
|
||||
protected TokenizerFactory tokenFactory;
|
||||
protected SentenceIterator sentenceIterator;
|
||||
protected DocumentIterator documentIterator;
|
||||
|
||||
public Builder() {
|
||||
super();
|
||||
}
|
||||
|
||||
|
||||
public Builder(@NonNull VectorsConfiguration configuration) {
|
||||
super(configuration);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* This method has no effect for GloVe
|
||||
*
|
||||
* @param vec existing WordVectors model
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Builder useExistingWordVectors(@NonNull WordVectors vec) {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder iterate(@NonNull SequenceIterator<VocabWord> iterator) {
|
||||
super.iterate(iterator);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Specifies minibatch size for training process.
|
||||
*
|
||||
* @param batchSize
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Builder batchSize(int batchSize) {
|
||||
super.batchSize(batchSize);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ierations and epochs are the same in GloVe implementation.
|
||||
*
|
||||
* @param iterations
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Builder iterations(int iterations) {
|
||||
super.epochs(iterations);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the number of iteration over training corpus during training
|
||||
*
|
||||
* @param numEpochs
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Builder epochs(int numEpochs) {
|
||||
super.epochs(numEpochs);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder useAdaGrad(boolean reallyUse) {
|
||||
super.useAdaGrad(true);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder layerSize(int layerSize) {
|
||||
super.layerSize(layerSize);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder learningRate(double learningRate) {
|
||||
super.learningRate(learningRate);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets minimum word frequency during vocabulary mastering.
|
||||
* Please note: this option is ignored, if vocabulary is built outside of GloVe
|
||||
*
|
||||
* @param minWordFrequency
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Builder minWordFrequency(int minWordFrequency) {
|
||||
super.minWordFrequency(minWordFrequency);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder minLearningRate(double minLearningRate) {
|
||||
super.minLearningRate(minLearningRate);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder resetModel(boolean reallyReset) {
|
||||
super.resetModel(reallyReset);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder vocabCache(@NonNull VocabCache<VocabWord> vocabCache) {
|
||||
super.vocabCache(vocabCache);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder lookupTable(@NonNull WeightLookupTable<VocabWord> lookupTable) {
|
||||
super.lookupTable(lookupTable);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Deprecated
|
||||
public Builder sampling(double sampling) {
|
||||
super.sampling(sampling);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Deprecated
|
||||
public Builder negativeSample(double negative) {
|
||||
super.negativeSample(negative);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder stopWords(@NonNull List<String> stopList) {
|
||||
super.stopWords(stopList);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder trainElementsRepresentation(boolean trainElements) {
|
||||
super.trainElementsRepresentation(true);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Deprecated
|
||||
public Builder trainSequencesRepresentation(boolean trainSequences) {
|
||||
super.trainSequencesRepresentation(false);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder stopWords(@NonNull Collection<VocabWord> stopList) {
|
||||
super.stopWords(stopList);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder windowSize(int windowSize) {
|
||||
super.windowSize(windowSize);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder seed(long randomSeed) {
|
||||
super.seed(randomSeed);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder workers(int numWorkers) {
|
||||
super.workers(numWorkers);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets TokenizerFactory to be used for training
|
||||
*
|
||||
* @param tokenizerFactory
|
||||
* @return
|
||||
*/
|
||||
public Builder tokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
|
||||
this.tokenFactory = tokenizerFactory;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameter specifying cutoff in weighting function; default 100.0
|
||||
*
|
||||
* @param xMax
|
||||
* @return
|
||||
*/
|
||||
public Builder xMax(double xMax) {
|
||||
this.xMax = xMax;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameters specifying, if cooccurrences list should be build into both directions from any current word.
|
||||
*
|
||||
* @param reallySymmetric
|
||||
* @return
|
||||
*/
|
||||
public Builder symmetric(boolean reallySymmetric) {
|
||||
this.symmetric = reallySymmetric;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameter specifying, if cooccurrences list should be shuffled between training epochs
|
||||
*
|
||||
* @param reallyShuffle
|
||||
* @return
|
||||
*/
|
||||
public Builder shuffle(boolean reallyShuffle) {
|
||||
this.shuffle = reallyShuffle;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method has no effect for ParagraphVectors
|
||||
*
|
||||
* @param windows
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Builder useVariableWindow(int... windows) {
|
||||
// no-op
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameter in exponent of weighting function; default 0.75
|
||||
*
|
||||
* @param alpha
|
||||
* @return
|
||||
*/
|
||||
public Builder alpha(double alpha) {
|
||||
this.alpha = alpha;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder iterate(@NonNull SentenceIterator iterator) {
|
||||
this.sentenceIterator = iterator;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder iterate(@NonNull DocumentIterator iterator) {
|
||||
this.sentenceIterator = new StreamLineIterator.Builder(iterator).setFetchSize(100).build();
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets ModelUtils that gonna be used as provider for utility methods: similarity(), wordsNearest(), accuracy(), etc
|
||||
*
|
||||
* @param modelUtils model utils to be used
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Builder modelUtils(@NonNull ModelUtils<VocabWord> modelUtils) {
|
||||
super.modelUtils(modelUtils);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method sets VectorsListeners for this SequenceVectors model
|
||||
*
|
||||
* @param vectorsListeners
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Builder setVectorsListeners(@NonNull Collection<VectorsListener<VocabWord>> vectorsListeners) {
|
||||
super.setVectorsListeners(vectorsListeners);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify maximum memory available for CoOccurrence map builder.
|
||||
*
|
||||
* Please note: this option can be considered a debugging method. In most cases setting proper -Xmx argument set to JVM is enough to limit this algorithm.
|
||||
* Please note: this option won't override -Xmx JVM value.
|
||||
*
|
||||
* @param gbytes memory limit, in gigabytes
|
||||
* @return
|
||||
*/
|
||||
public Builder maxMemory(int gbytes) {
|
||||
this.maxmemory = gbytes;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify SequenceElement that will be used as UNK element, if UNK is used
|
||||
*
|
||||
* @param element
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Builder unknownElement(VocabWord element) {
|
||||
super.unknownElement(element);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify, if UNK word should be used internally
|
||||
*
|
||||
* @param reallyUse
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Builder useUnknown(boolean reallyUse) {
|
||||
super.useUnknown(reallyUse);
|
||||
if (this.unknownElement == null) {
|
||||
this.unknownElement(new VocabWord(1.0, Glove.DEFAULT_UNK));
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public Glove build() {
|
||||
presetTables();
|
||||
|
||||
Glove ret = new Glove();
|
||||
|
||||
|
||||
// hardcoded value for glove
|
||||
|
||||
if (sentenceIterator != null) {
|
||||
SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(sentenceIterator)
|
||||
.tokenizerFactory(tokenFactory).build();
|
||||
this.iterator = new AbstractSequenceIterator.Builder<>(transformer).build();
|
||||
}
|
||||
|
||||
|
||||
ret.trainElementsVectors = true;
|
||||
ret.trainSequenceVectors = false;
|
||||
ret.useAdeGrad = true;
|
||||
this.useAdaGrad = true;
|
||||
|
||||
ret.learningRate.set(this.learningRate);
|
||||
ret.resetModel = this.resetModel;
|
||||
ret.batchSize = this.batchSize;
|
||||
ret.iterator = this.iterator;
|
||||
ret.numEpochs = this.numEpochs;
|
||||
ret.numIterations = this.iterations;
|
||||
ret.layerSize = this.layerSize;
|
||||
|
||||
ret.useUnknown = this.useUnknown;
|
||||
ret.unknownElement = this.unknownElement;
|
||||
|
||||
|
||||
|
||||
this.configuration.setLearningRate(this.learningRate);
|
||||
this.configuration.setLayersSize(layerSize);
|
||||
this.configuration.setHugeModelExpected(hugeModelExpected);
|
||||
this.configuration.setWindow(window);
|
||||
this.configuration.setMinWordFrequency(minWordFrequency);
|
||||
this.configuration.setIterations(iterations);
|
||||
this.configuration.setSeed(seed);
|
||||
this.configuration.setBatchSize(batchSize);
|
||||
this.configuration.setLearningRateDecayWords(learningRateDecayWords);
|
||||
this.configuration.setMinLearningRate(minLearningRate);
|
||||
this.configuration.setSampling(this.sampling);
|
||||
this.configuration.setUseAdaGrad(useAdaGrad);
|
||||
this.configuration.setNegative(negative);
|
||||
this.configuration.setEpochs(this.numEpochs);
|
||||
|
||||
|
||||
ret.configuration = this.configuration;
|
||||
|
||||
ret.lookupTable = this.lookupTable;
|
||||
ret.vocab = this.vocabCache;
|
||||
ret.modelUtils = this.modelUtils;
|
||||
ret.eventListeners = this.vectorsListeners;
|
||||
|
||||
|
||||
ret.elementsLearningAlgorithm = new GloVe.Builder<VocabWord>().learningRate(this.learningRate)
|
||||
.shuffle(this.shuffle).symmetric(this.symmetric).xMax(this.xMax).alpha(this.alpha)
|
||||
.maxMemory(maxmemory).build();
|
||||
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,334 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove;
|
||||
|
||||
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.apache.commons.io.LineIterator;
|
||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
import org.deeplearning4j.models.word2vec.Word2Vec;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.legacy.AdaGrad;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
/**
|
||||
* Glove lookup table
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
// Deprecated due to logic being pulled off WeightLookupTable classes into LearningAlgorithm interfaces for better code.
|
||||
@Deprecated
|
||||
public class GloveWeightLookupTable<T extends SequenceElement> extends InMemoryLookupTable<T> {
|
||||
|
||||
|
||||
private AdaGrad weightAdaGrad;
|
||||
private AdaGrad biasAdaGrad;
|
||||
private INDArray bias;
|
||||
//also known as alpha
|
||||
private double xMax = 0.75;
|
||||
private double maxCount = 100;
|
||||
|
||||
|
||||
public GloveWeightLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, Random gen,
|
||||
double negative, double xMax, double maxCount) {
|
||||
super(vocab, vectorLength, useAdaGrad, lr, gen, negative);
|
||||
this.xMax = xMax;
|
||||
this.maxCount = maxCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void resetWeights(boolean reset) {
|
||||
if (rng == null)
|
||||
this.rng = Nd4j.getRandom();
|
||||
|
||||
//note the +2 which is the unk vocab word and the bias
|
||||
if (syn0 == null || reset) {
|
||||
syn0 = Nd4j.rand(new int[] {vocab.numWords() + 1, vectorLength}, rng).subi(0.5).divi((double) vectorLength);
|
||||
INDArray randUnk = Nd4j.rand(1, vectorLength, rng).subi(0.5).divi(vectorLength);
|
||||
putVector(Word2Vec.DEFAULT_UNK, randUnk);
|
||||
}
|
||||
if (weightAdaGrad == null || reset) {
|
||||
weightAdaGrad = new AdaGrad(new long[]{vocab.numWords() + 1, vectorLength}, lr.get());
|
||||
}
|
||||
|
||||
|
||||
//right after unknown
|
||||
if (bias == null || reset)
|
||||
bias = Nd4j.create(syn0.rows());
|
||||
|
||||
if (biasAdaGrad == null || reset) {
|
||||
biasAdaGrad = new AdaGrad(bias.shape(), lr.get());
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset the weights of the cache
|
||||
*/
|
||||
@Override
|
||||
public void resetWeights() {
|
||||
resetWeights(true);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* glove iteration
|
||||
* @param w1 the first word
|
||||
* @param w2 the second word
|
||||
* @param score the weight learned for the particular co occurrences
|
||||
*/
|
||||
public double iterateSample(T w1, T w2, double score) {
|
||||
INDArray w1Vector = syn0.slice(w1.getIndex());
|
||||
INDArray w2Vector = syn0.slice(w2.getIndex());
|
||||
//prediction: input + bias
|
||||
if (w1.getIndex() < 0 || w1.getIndex() >= syn0.rows())
|
||||
throw new IllegalArgumentException("Illegal index for word " + w1.getLabel());
|
||||
if (w2.getIndex() < 0 || w2.getIndex() >= syn0.rows())
|
||||
throw new IllegalArgumentException("Illegal index for word " + w2.getLabel());
|
||||
|
||||
|
||||
//w1 * w2 + bias
|
||||
double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
|
||||
prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
|
||||
|
||||
double weight = Math.pow(Math.min(1.0, (score / maxCount)), xMax);
|
||||
|
||||
double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
|
||||
if (Double.isNaN(fDiff))
|
||||
fDiff = Nd4j.EPS_THRESHOLD;
|
||||
//amount of change
|
||||
double gradient = fDiff;
|
||||
|
||||
//note the update step here: the gradient is
|
||||
//the gradient of the OPPOSITE word
|
||||
//for adagrad we will use the index of the word passed in
|
||||
//for the gradient calculation we will use the context vector
|
||||
update(w1, w1Vector, w2Vector, gradient);
|
||||
update(w2, w2Vector, w1Vector, gradient);
|
||||
return fDiff;
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
private void update(T w1, INDArray wordVector, INDArray contextVector, double gradient) {
|
||||
//gradient for word vectors
|
||||
INDArray grad1 = contextVector.mul(gradient);
|
||||
INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape());
|
||||
|
||||
//update vector
|
||||
wordVector.subi(update);
|
||||
|
||||
double w1Bias = bias.getDouble(w1.getIndex());
|
||||
double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape());
|
||||
double update2 = w1Bias - biasGradient;
|
||||
bias.putScalar(w1.getIndex(), update2);
|
||||
}
|
||||
|
||||
public AdaGrad getWeightAdaGrad() {
|
||||
return weightAdaGrad;
|
||||
}
|
||||
|
||||
|
||||
public AdaGrad getBiasAdaGrad() {
|
||||
return biasAdaGrad;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Load a glove model from an input stream.
|
||||
* The format is:
|
||||
* word num1 num2....
|
||||
* @param is the input stream to read from for the weights
|
||||
* @param vocab the vocab for the lookuptable
|
||||
* @return the loaded model
|
||||
* @throws java.io.IOException if one occurs
|
||||
*/
|
||||
public static GloveWeightLookupTable load(InputStream is, VocabCache<? extends SequenceElement> vocab)
|
||||
throws IOException {
|
||||
LineIterator iter = IOUtils.lineIterator(is, "UTF-8");
|
||||
GloveWeightLookupTable glove = null;
|
||||
Map<String, float[]> wordVectors = new HashMap<>();
|
||||
while (iter.hasNext()) {
|
||||
String line = iter.nextLine().trim();
|
||||
if (line.isEmpty())
|
||||
continue;
|
||||
String[] split = line.split(" ");
|
||||
String word = split[0];
|
||||
if (glove == null)
|
||||
glove = new GloveWeightLookupTable.Builder().cache(vocab).vectorLength(split.length - 1).build();
|
||||
|
||||
|
||||
|
||||
if (word.isEmpty())
|
||||
continue;
|
||||
float[] read = read(split, glove.layerSize());
|
||||
if (read.length < 1)
|
||||
continue;
|
||||
|
||||
|
||||
wordVectors.put(word, read);
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
glove.setSyn0(weights(glove, wordVectors, vocab));
|
||||
glove.resetWeights(false);
|
||||
|
||||
|
||||
iter.close();
|
||||
|
||||
|
||||
return glove;
|
||||
|
||||
}
|
||||
|
||||
private static INDArray weights(GloveWeightLookupTable glove, Map<String, float[]> data, VocabCache vocab) {
|
||||
INDArray ret = Nd4j.create(data.size(), glove.layerSize());
|
||||
|
||||
for (Map.Entry<String, float[]> entry : data.entrySet()) {
|
||||
String key = entry.getKey();
|
||||
INDArray row = Nd4j.create(Nd4j.createBuffer(entry.getValue()));
|
||||
if (row.length() != glove.layerSize())
|
||||
continue;
|
||||
if (vocab.indexOf(key) >= data.size())
|
||||
continue;
|
||||
if (vocab.indexOf(key) < 0)
|
||||
continue;
|
||||
ret.putRow(vocab.indexOf(key), row);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
private static float[] read(String[] split, int length) {
|
||||
float[] ret = new float[length];
|
||||
for (int i = 1; i < split.length; i++) {
|
||||
ret[i - 1] = Float.parseFloat(split[i]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha) {
|
||||
throw new UnsupportedOperationException();
|
||||
|
||||
}
|
||||
|
||||
public double getxMax() {
|
||||
return xMax;
|
||||
}
|
||||
|
||||
public void setxMax(double xMax) {
|
||||
this.xMax = xMax;
|
||||
}
|
||||
|
||||
public double getMaxCount() {
|
||||
return maxCount;
|
||||
}
|
||||
|
||||
public void setMaxCount(double maxCount) {
|
||||
this.maxCount = maxCount;
|
||||
}
|
||||
|
||||
public INDArray getBias() {
|
||||
return bias;
|
||||
}
|
||||
|
||||
public void setBias(INDArray bias) {
|
||||
this.bias = bias;
|
||||
}
|
||||
|
||||
public static class Builder<T extends SequenceElement> extends InMemoryLookupTable.Builder<T> {
|
||||
private double xMax = 0.75;
|
||||
private double maxCount = 100;
|
||||
|
||||
|
||||
public Builder<T> maxCount(double maxCount) {
|
||||
this.maxCount = maxCount;
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
public Builder<T> xMax(double xMax) {
|
||||
this.xMax = xMax;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder<T> cache(VocabCache<T> vocab) {
|
||||
super.cache(vocab);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder<T> negative(double negative) {
|
||||
super.negative(negative);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder<T> vectorLength(int vectorLength) {
|
||||
super.vectorLength(vectorLength);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder<T> useAdaGrad(boolean useAdaGrad) {
|
||||
super.useAdaGrad(useAdaGrad);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder<T> lr(double lr) {
|
||||
super.lr(lr);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder<T> gen(Random gen) {
|
||||
super.gen(gen);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder<T> seed(long seed) {
|
||||
super.seed(seed);
|
||||
return this;
|
||||
}
|
||||
|
||||
public GloveWeightLookupTable<T> build() {
|
||||
return new GloveWeightLookupTable<>(vocabCache, vectorLength, useAdaGrad, lr, gen, negative, xMax,
|
||||
maxCount);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -1,91 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove.count;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.PrintWriter;
|
||||
|
||||
/**
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class ASCIICoOccurrenceReader<T extends SequenceElement> implements CoOccurenceReader<T> {
|
||||
private File file;
|
||||
private PrintWriter writer;
|
||||
private SentenceIterator iterator;
|
||||
private VocabCache<T> vocabCache;
|
||||
|
||||
public ASCIICoOccurrenceReader(@NonNull File file, @NonNull VocabCache<T> vocabCache) {
|
||||
this.vocabCache = vocabCache;
|
||||
this.file = file;
|
||||
try {
|
||||
iterator = new PrefetchingSentenceIterator.Builder(new BasicLineIterator(file)).build();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public boolean hasMoreObjects() {
|
||||
return iterator.hasNext();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns next CoOccurrenceWeight object
|
||||
*
|
||||
* PLEASE NOTE: This method can return null value.
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public CoOccurrenceWeight<T> nextObject() {
|
||||
String line = iterator.nextSentence();
|
||||
if (line == null || line.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
String[] strings = line.split(" ");
|
||||
|
||||
CoOccurrenceWeight<T> object = new CoOccurrenceWeight<>();
|
||||
object.setElement1(vocabCache.elementAtIndex(Integer.valueOf(strings[0])));
|
||||
object.setElement2(vocabCache.elementAtIndex(Integer.valueOf(strings[1])));
|
||||
object.setWeight(Double.parseDouble(strings[2]));
|
||||
|
||||
return object;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public void finish() {
|
||||
try {
|
||||
if (writer != null) {
|
||||
writer.flush();
|
||||
writer.close();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -1,69 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove.count;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
|
||||
import java.io.BufferedOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.PrintWriter;
|
||||
|
||||
/**
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class ASCIICoOccurrenceWriter<T extends SequenceElement> implements CoOccurrenceWriter<T> {
|
||||
|
||||
private File file;
|
||||
private PrintWriter writer;
|
||||
|
||||
public ASCIICoOccurrenceWriter(@NonNull File file) {
|
||||
this.file = file;
|
||||
try {
|
||||
this.writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(file), 10 * 1024 * 1024));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeObject(CoOccurrenceWeight<T> object) {
|
||||
StringBuilder builder = new StringBuilder(String.valueOf(object.getElement1().getIndex())).append(" ")
|
||||
.append(String.valueOf(object.getElement2().getIndex())).append(" ")
|
||||
.append(String.valueOf(object.getWeight()));
|
||||
writer.println(builder.toString());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void queueObject(CoOccurrenceWeight<T> object) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() {
|
||||
try {
|
||||
writer.flush();
|
||||
} catch (Exception e) {
|
||||
}
|
||||
|
||||
try {
|
||||
writer.close();
|
||||
} catch (Exception e) {
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,245 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove.count;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ArrayBlockingQueue;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
/**
|
||||
* Binary implementation of CoOccurenceReader interface, used to provide off-memory storage for cooccurrence maps generated for GloVe
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class BinaryCoOccurrenceReader<T extends SequenceElement> implements CoOccurenceReader<T> {
|
||||
private VocabCache<T> vocabCache;
|
||||
private InputStream inputStream;
|
||||
private File file;
|
||||
private ArrayBlockingQueue<CoOccurrenceWeight<T>> buffer;
|
||||
int workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);
|
||||
private StreamReaderThread readerThread;
|
||||
private CountMap<T> countMap;
|
||||
|
||||
|
||||
protected static final Logger logger = LoggerFactory.getLogger(BinaryCoOccurrenceReader.class);
|
||||
|
||||
public BinaryCoOccurrenceReader(@NonNull File file, @NonNull VocabCache<T> vocabCache, CountMap<T> map) {
|
||||
this.vocabCache = vocabCache;
|
||||
this.file = file;
|
||||
this.countMap = map;
|
||||
buffer = new ArrayBlockingQueue<>(200000);
|
||||
|
||||
try {
|
||||
inputStream = new BufferedInputStream(new FileInputStream(this.file), 100 * 1024 * 1024);
|
||||
//inputStream = new BufferedInputStream(new FileInputStream(file), 1024 * 1024);
|
||||
readerThread = new StreamReaderThread(inputStream);
|
||||
readerThread.start();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasMoreObjects() {
|
||||
|
||||
if (!buffer.isEmpty())
|
||||
return true;
|
||||
|
||||
try {
|
||||
return readerThread.hasMoreObjects() || !buffer.isEmpty();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
//return false;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public CoOccurrenceWeight<T> nextObject() {
|
||||
if (!buffer.isEmpty()) {
|
||||
return buffer.poll();
|
||||
} else {
|
||||
// buffer can be starved, or we're already at the end of file.
|
||||
if (readerThread.hasMoreObjects()) {
|
||||
try {
|
||||
return buffer.poll(3, TimeUnit.SECONDS);
|
||||
} catch (Exception e) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return null;
|
||||
/*
|
||||
try {
|
||||
CoOccurrenceWeight<T> ret = new CoOccurrenceWeight<>();
|
||||
ret.setElement1(vocabCache.elementAtIndex(inputStream.readInt()));
|
||||
ret.setElement2(vocabCache.elementAtIndex(inputStream.readInt()));
|
||||
ret.setWeight(inputStream.readDouble());
|
||||
|
||||
return ret;
|
||||
} catch (Exception e) {
|
||||
return null;
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() {
|
||||
try {
|
||||
if (inputStream != null)
|
||||
inputStream.close();
|
||||
} catch (Exception e) {
|
||||
//
|
||||
}
|
||||
}
|
||||
|
||||
private class StreamReaderThread extends Thread implements Runnable {
|
||||
private InputStream stream;
|
||||
private AtomicBoolean isReading = new AtomicBoolean(false);
|
||||
|
||||
public StreamReaderThread(@NonNull InputStream stream) {
|
||||
this.stream = stream;
|
||||
isReading.set(false);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
// we read pre-defined number of objects as byte array
|
||||
byte[] array = new byte[16 * 500000];
|
||||
while (true) {
|
||||
int count = stream.read(array);
|
||||
|
||||
isReading.set(true);
|
||||
if (count == 0)
|
||||
break;
|
||||
|
||||
// now we deserialize them in separate threads to gain some speedup, if possible
|
||||
List<AsyncDeserializationThread> threads = new ArrayList<>();
|
||||
AtomicInteger internalPosition = new AtomicInteger(0);
|
||||
|
||||
for (int t = 0; t < workers; t++) {
|
||||
threads.add(t, new AsyncDeserializationThread(t, array, buffer, internalPosition, count));
|
||||
threads.get(t).start();
|
||||
}
|
||||
|
||||
// we'll block this cycle untill all objects are fit into queue
|
||||
for (int t = 0; t < workers; t++) {
|
||||
try {
|
||||
threads.get(t).join();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
isReading.set(false);
|
||||
if (count < array.length)
|
||||
break;
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
isReading.set(false);
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public boolean hasMoreObjects() {
|
||||
try {
|
||||
return stream.available() > 0 || isReading.get();
|
||||
} catch (Exception e) {
|
||||
return false;
|
||||
} finally {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility class that accepts byte array as input, and deserialize it into set of CoOccurrenceWeight objects
|
||||
*/
|
||||
private class AsyncDeserializationThread extends Thread implements Runnable {
|
||||
private int threadId;
|
||||
private byte[] arrayReference;
|
||||
private ArrayBlockingQueue<CoOccurrenceWeight<T>> targetBuffer;
|
||||
private AtomicInteger pointer;
|
||||
private int limit;
|
||||
|
||||
public AsyncDeserializationThread(int threadId, @NonNull byte[] array,
|
||||
@NonNull ArrayBlockingQueue<CoOccurrenceWeight<T>> targetBuffer,
|
||||
@NonNull AtomicInteger sharedPointer, int limit) {
|
||||
this.threadId = threadId;
|
||||
this.arrayReference = array;
|
||||
this.targetBuffer = targetBuffer;
|
||||
this.pointer = sharedPointer;
|
||||
this.limit = limit;
|
||||
|
||||
|
||||
setName("AsynDeserialization thread " + this.threadId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
ByteBuffer bB = ByteBuffer.wrap(arrayReference);
|
||||
int position = 0;
|
||||
while ((position = pointer.getAndAdd(16)) < this.limit) {
|
||||
if (position >= limit) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
int e1idx = bB.getInt(position);
|
||||
int e2idx = bB.getInt(position + 4);
|
||||
double eW = bB.getDouble(position + 8);
|
||||
|
||||
|
||||
CoOccurrenceWeight<T> object = new CoOccurrenceWeight<>();
|
||||
object.setElement1(vocabCache.elementAtIndex(e1idx));
|
||||
object.setElement2(vocabCache.elementAtIndex(e2idx));
|
||||
|
||||
if (countMap != null) {
|
||||
double mW = countMap.getCount(object.getElement1(), object.getElement2());
|
||||
|
||||
if (mW > 0) {
|
||||
eW += mW;
|
||||
countMap.removePair(object.getElement1(), object.getElement2());
|
||||
}
|
||||
}
|
||||
object.setWeight(eW);
|
||||
|
||||
try {
|
||||
targetBuffer.put(object);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,78 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove.count;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.BufferedOutputStream;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
|
||||
/**
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class BinaryCoOccurrenceWriter<T extends SequenceElement> implements CoOccurrenceWriter<T> {
|
||||
private File file;
|
||||
private DataOutputStream outputStream;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(BinaryCoOccurrenceWriter.class);
|
||||
|
||||
public BinaryCoOccurrenceWriter(@NonNull File file) {
|
||||
this.file = file;
|
||||
|
||||
try {
|
||||
outputStream = new DataOutputStream(
|
||||
new BufferedOutputStream(new FileOutputStream(file), 100 * 1024 * 1024));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeObject(@NonNull CoOccurrenceWeight<T> object) {
|
||||
try {
|
||||
// log.info("Saving objects: { [" +object.getElement1().getIndex() +"], [" + object.getElement2().getIndex() + "] }");
|
||||
outputStream.writeInt(object.getElement1().getIndex());
|
||||
outputStream.writeInt(object.getElement2().getIndex());
|
||||
outputStream.writeDouble(object.getWeight());
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void queueObject(CoOccurrenceWeight<T> object) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() {
|
||||
try {
|
||||
outputStream.flush();
|
||||
} catch (Exception e) {
|
||||
}
|
||||
|
||||
try {
|
||||
outputStream.close();
|
||||
} catch (Exception e) {
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,34 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove.count;
|
||||
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
|
||||
/**
|
||||
* Created by raver on 24.12.2015.
|
||||
*/
|
||||
public interface CoOccurenceReader<T extends SequenceElement> {
|
||||
/*
|
||||
Storage->Memory merging part
|
||||
*/
|
||||
boolean hasMoreObjects();
|
||||
|
||||
|
||||
CoOccurrenceWeight<T> nextObject();
|
||||
|
||||
void finish();
|
||||
}
|
|
@ -1,54 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove.count;
|
||||
|
||||
import lombok.Data;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
|
||||
/**
|
||||
* Simple POJO holding pairs of elements and their respective weights, used in GloVe -> CoOccurrence
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
@Data
|
||||
public class CoOccurrenceWeight<T extends SequenceElement> {
|
||||
private T element1;
|
||||
private T element2;
|
||||
private double weight;
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o)
|
||||
return true;
|
||||
if (o == null || getClass() != o.getClass())
|
||||
return false;
|
||||
|
||||
CoOccurrenceWeight<?> that = (CoOccurrenceWeight<?>) o;
|
||||
|
||||
if (element1 != null ? !element1.equals(that.element1) : that.element1 != null)
|
||||
return false;
|
||||
return element2 != null ? element2.equals(that.element2) : that.element2 == null;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = element1 != null ? element1.hashCode() : 0;
|
||||
result = 31 * result + (element2 != null ? element2.hashCode() : 0);
|
||||
return result;
|
||||
}
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove.count;
|
||||
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
|
||||
/**
|
||||
* Created by fartovii on 25.12.15.
|
||||
*/
|
||||
public interface CoOccurrenceWriter<T extends SequenceElement> {
|
||||
|
||||
/**
|
||||
* This method implementations should write out objects immediately
|
||||
* @param object
|
||||
*/
|
||||
void writeObject(CoOccurrenceWeight<T> object);
|
||||
|
||||
/**
|
||||
* This method implementations should queue objects for writing out.
|
||||
*
|
||||
* @param object
|
||||
*/
|
||||
void queueObject(CoOccurrenceWeight<T> object);
|
||||
|
||||
/**
|
||||
* Implementations of this method should close everything they use, before eradication
|
||||
*/
|
||||
void finish();
|
||||
}
|
|
@ -1,99 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove.count;
|
||||
|
||||
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
|
||||
import java.util.Iterator;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* Drop-in replacement for CounterMap
|
||||
*
|
||||
* WORK IN PROGRESS, PLEASE DO NOT USE
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class CountMap<T extends SequenceElement> {
|
||||
private volatile Map<Pair<T, T>, AtomicDouble> backingMap = new ConcurrentHashMap<>();
|
||||
|
||||
public CountMap() {
|
||||
// placeholder
|
||||
}
|
||||
|
||||
public void incrementCount(T element1, T element2, double weight) {
|
||||
Pair<T, T> tempEntry = new Pair<>(element1, element2);
|
||||
if (backingMap.containsKey(tempEntry)) {
|
||||
backingMap.get(tempEntry).addAndGet(weight);
|
||||
} else {
|
||||
backingMap.put(tempEntry, new AtomicDouble(weight));
|
||||
}
|
||||
}
|
||||
|
||||
public void removePair(T element1, T element2) {
|
||||
Pair<T, T> tempEntry = new Pair<>(element1, element2);
|
||||
backingMap.remove(tempEntry);
|
||||
}
|
||||
|
||||
public void removePair(Pair<T, T> pair) {
|
||||
backingMap.remove(pair);
|
||||
}
|
||||
|
||||
public double getCount(T element1, T element2) {
|
||||
Pair<T, T> tempEntry = new Pair<>(element1, element2);
|
||||
if (backingMap.containsKey(tempEntry)) {
|
||||
return backingMap.get(tempEntry).get();
|
||||
} else
|
||||
return 0;
|
||||
}
|
||||
|
||||
public double getCount(Pair<T, T> pair) {
|
||||
if (backingMap.containsKey(pair)) {
|
||||
return backingMap.get(pair).get();
|
||||
} else
|
||||
return 0;
|
||||
}
|
||||
|
||||
public Iterator<Pair<T, T>> getPairIterator() {
|
||||
return new Iterator<Pair<T, T>>() {
|
||||
private Iterator<Pair<T, T>> iterator = backingMap.keySet().iterator();
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return iterator.hasNext();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pair<T, T> next() {
|
||||
//MapEntry<T> entry = iterator.next();
|
||||
return iterator.next(); //new Pair<>(entry.getElement1(), entry.getElement2());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void remove() {
|
||||
throw new UnsupportedOperationException("remove() isn't supported here");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public int size() {
|
||||
return backingMap.size();
|
||||
}
|
||||
}
|
|
@ -1,86 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove.count;
|
||||
|
||||
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||
|
||||
/**
|
||||
* Simple circular counter, that circulates within 0...Limit, both inclusive
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class RoundCount {
|
||||
|
||||
private int limit = 0;
|
||||
private int lower = 0;
|
||||
private int value = 0;
|
||||
|
||||
private ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
|
||||
|
||||
/**
|
||||
* Creates new RoundCount instance.
|
||||
*
|
||||
* @param limit Maximum top value for this counter. Inclusive.
|
||||
*/
|
||||
public RoundCount(int limit) {
|
||||
this.limit = limit;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates new RoundCount instance.
|
||||
*
|
||||
* @param lower - Minimum value for this counter. Inclusive
|
||||
* @param top - Maximum value for this counter. Inclusive.
|
||||
*/
|
||||
public RoundCount(int lower, int top) {
|
||||
this.limit = top;
|
||||
this.lower = lower;
|
||||
}
|
||||
|
||||
public int previous() {
|
||||
try {
|
||||
lock.readLock().lock();
|
||||
if (value == lower)
|
||||
return limit;
|
||||
else
|
||||
return value - 1;
|
||||
} finally {
|
||||
lock.readLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public int get() {
|
||||
try {
|
||||
lock.readLock().lock();
|
||||
return value;
|
||||
} finally {
|
||||
lock.readLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public void tick() {
|
||||
try {
|
||||
lock.writeLock().lock();
|
||||
if (value == limit)
|
||||
value = lower;
|
||||
else
|
||||
value++;
|
||||
} finally {
|
||||
lock.writeLock().unlock();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -763,7 +763,7 @@ public class ParagraphVectors extends Word2Vec {
|
|||
|
||||
|
||||
/**
|
||||
* This method allows you to use pre-built WordVectors model (Word2Vec or GloVe) for ParagraphVectors.
|
||||
* This method allows you to use pre-built WordVectors model (e.g. Word2Vec) for ParagraphVectors.
|
||||
* Existing model will be transferred into new model before training starts.
|
||||
*
|
||||
* PLEASE NOTE: Non-normalized model is recommended to use here.
|
||||
|
|
|
@ -520,7 +520,7 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
|
|||
}
|
||||
|
||||
/**
|
||||
* This method allows you to use pre-built WordVectors model (SkipGram or GloVe) for DBOW sequence learning.
|
||||
* This method allows you to use pre-built WordVectors model (e.g. SkipGram) for DBOW sequence learning.
|
||||
* Existing model will be transferred into new model before training starts.
|
||||
*
|
||||
* PLEASE NOTE: This model has no effect for elements learning algorithms. Only sequence learning is affected.
|
||||
|
|
|
@ -1,101 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
|
||||
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
|
||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabConstructor;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
|
||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
||||
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotEquals;
|
||||
|
||||
/**
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class AbstractCoOccurrencesTest extends BaseDL4JTest {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(AbstractCoOccurrencesTest.class);
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFit1() throws Exception {
|
||||
ClassPathResource resource = new ClassPathResource("other/oneline.txt");
|
||||
File file = resource.getFile();
|
||||
|
||||
AbstractCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
|
||||
BasicLineIterator underlyingIterator = new BasicLineIterator(file);
|
||||
|
||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||
|
||||
SentenceTransformer transformer =
|
||||
new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build();
|
||||
|
||||
AbstractSequenceIterator<VocabWord> sequenceIterator =
|
||||
new AbstractSequenceIterator.Builder<>(transformer).build();
|
||||
|
||||
VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>()
|
||||
.addSource(sequenceIterator, 1).setTargetVocabCache(vocabCache).build();
|
||||
|
||||
constructor.buildJointVocabulary(false, true);
|
||||
|
||||
AbstractCoOccurrences<VocabWord> coOccurrences = new AbstractCoOccurrences.Builder<VocabWord>()
|
||||
.iterate(sequenceIterator).vocabCache(vocabCache).symmetric(false).windowSize(15).build();
|
||||
|
||||
coOccurrences.fit();
|
||||
|
||||
//List<Pair<VocabWord, VocabWord>> list = coOccurrences.i();
|
||||
Iterator<Pair<Pair<VocabWord, VocabWord>, Double>> iterator = coOccurrences.iterator();
|
||||
assertNotEquals(null, iterator);
|
||||
int cnt = 0;
|
||||
|
||||
List<Pair<VocabWord, VocabWord>> list = new ArrayList<>();
|
||||
while (iterator.hasNext()) {
|
||||
Pair<Pair<VocabWord, VocabWord>, Double> pair = iterator.next();
|
||||
list.add(pair.getFirst());
|
||||
cnt++;
|
||||
}
|
||||
|
||||
|
||||
log.info("CoOccurrences: " + list);
|
||||
|
||||
assertEquals(16, list.size());
|
||||
assertEquals(16, cnt);
|
||||
}
|
||||
}
|
|
@ -1,137 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
|
||||
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
import org.junit.Before;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.Collection;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
/**
|
||||
* Created by agibsonccc on 12/3/14.
|
||||
*/
|
||||
public class GloveTest extends BaseDL4JTest {
|
||||
private static final Logger log = LoggerFactory.getLogger(GloveTest.class);
|
||||
private Glove glove;
|
||||
private SentenceIterator iter;
|
||||
|
||||
@Before
|
||||
public void before() throws Exception {
|
||||
|
||||
ClassPathResource resource = new ClassPathResource("/raw_sentences.txt");
|
||||
File file = resource.getFile();
|
||||
iter = new LineSentenceIterator(file);
|
||||
iter.setPreProcessor(new SentencePreProcessor() {
|
||||
@Override
|
||||
public String preProcess(String sentence) {
|
||||
return sentence.toLowerCase();
|
||||
}
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
public void testGlove() throws Exception {
|
||||
/*
|
||||
glove = new Glove.Builder().iterate(iter).symmetric(true).shuffle(true)
|
||||
.minWordFrequency(1).iterations(10).learningRate(0.1)
|
||||
.layerSize(300)
|
||||
.build();
|
||||
|
||||
glove.fit();
|
||||
Collection<String> words = glove.wordsNearest("day", 20);
|
||||
log.info("Nearest words to 'day': " + words);
|
||||
assertTrue(words.contains("week"));
|
||||
|
||||
*/
|
||||
|
||||
}
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
public void testGloVe1() throws Exception {
|
||||
File inputFile = Resources.asFile("big/raw_sentences.txt");
|
||||
|
||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||
// Split on white spaces in the line to get words
|
||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||
|
||||
Glove glove = new Glove.Builder().iterate(iter).tokenizerFactory(t).alpha(0.75).learningRate(0.1).epochs(45)
|
||||
.xMax(100).shuffle(true).symmetric(true).build();
|
||||
|
||||
glove.fit();
|
||||
|
||||
double simD = glove.similarity("day", "night");
|
||||
double simP = glove.similarity("best", "police");
|
||||
|
||||
|
||||
|
||||
log.info("Day/night similarity: " + simD);
|
||||
log.info("Best/police similarity: " + simP);
|
||||
|
||||
Collection<String> words = glove.wordsNearest("day", 10);
|
||||
log.info("Nearest words to 'day': " + words);
|
||||
|
||||
|
||||
assertTrue(simD > 0.7);
|
||||
|
||||
// actually simP should be somewhere at 0
|
||||
assertTrue(simP < 0.5);
|
||||
|
||||
assertTrue(words.contains("night"));
|
||||
assertTrue(words.contains("year"));
|
||||
assertTrue(words.contains("week"));
|
||||
|
||||
File tempFile = File.createTempFile("glove", "temp");
|
||||
tempFile.deleteOnExit();
|
||||
|
||||
INDArray day1 = glove.getWordVectorMatrix("day").dup();
|
||||
|
||||
WordVectorSerializer.writeWordVectors(glove, tempFile);
|
||||
|
||||
WordVectors vectors = WordVectorSerializer.loadTxtVectors(tempFile);
|
||||
|
||||
INDArray day2 = vectors.getWordVectorMatrix("day").dup();
|
||||
|
||||
assertEquals(day1, day2);
|
||||
|
||||
tempFile.delete();
|
||||
}
|
||||
}
|
|
@ -1,156 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove.count;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.models.word2vec.Huffman;
|
||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
import static org.junit.Assert.assertNotEquals;
|
||||
|
||||
/**
|
||||
* Created by fartovii on 25.12.15.
|
||||
*/
|
||||
public class BinaryCoOccurrenceReaderTest extends BaseDL4JTest {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(BinaryCoOccurrenceReaderTest.class);
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testHasMoreObjects1() throws Exception {
|
||||
File tempFile = File.createTempFile("tmp", "tmp");
|
||||
tempFile.deleteOnExit();
|
||||
|
||||
VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
|
||||
|
||||
VocabWord word1 = new VocabWord(1.0, "human");
|
||||
VocabWord word2 = new VocabWord(2.0, "animal");
|
||||
VocabWord word3 = new VocabWord(3.0, "unknown");
|
||||
|
||||
vocabCache.addToken(word1);
|
||||
vocabCache.addToken(word2);
|
||||
vocabCache.addToken(word3);
|
||||
|
||||
Huffman huffman = new Huffman(vocabCache.vocabWords());
|
||||
huffman.build();
|
||||
huffman.applyIndexes(vocabCache);
|
||||
|
||||
|
||||
BinaryCoOccurrenceWriter<VocabWord> writer = new BinaryCoOccurrenceWriter<>(tempFile);
|
||||
|
||||
CoOccurrenceWeight<VocabWord> object1 = new CoOccurrenceWeight<>();
|
||||
object1.setElement1(word1);
|
||||
object1.setElement2(word2);
|
||||
object1.setWeight(3.14159265);
|
||||
|
||||
writer.writeObject(object1);
|
||||
|
||||
CoOccurrenceWeight<VocabWord> object2 = new CoOccurrenceWeight<>();
|
||||
object2.setElement1(word2);
|
||||
object2.setElement2(word3);
|
||||
object2.setWeight(0.197);
|
||||
|
||||
writer.writeObject(object2);
|
||||
|
||||
writer.finish();
|
||||
|
||||
BinaryCoOccurrenceReader<VocabWord> reader = new BinaryCoOccurrenceReader<>(tempFile, vocabCache, null);
|
||||
|
||||
|
||||
CoOccurrenceWeight<VocabWord> r1 = reader.nextObject();
|
||||
log.info("Object received: " + r1);
|
||||
assertNotEquals(null, r1);
|
||||
|
||||
r1 = reader.nextObject();
|
||||
log.info("Object received: " + r1);
|
||||
assertNotEquals(null, r1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testHasMoreObjects2() throws Exception {
|
||||
File tempFile = File.createTempFile("tmp", "tmp");
|
||||
tempFile.deleteOnExit();
|
||||
|
||||
VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
|
||||
|
||||
VocabWord word1 = new VocabWord(1.0, "human");
|
||||
VocabWord word2 = new VocabWord(2.0, "animal");
|
||||
VocabWord word3 = new VocabWord(3.0, "unknown");
|
||||
|
||||
vocabCache.addToken(word1);
|
||||
vocabCache.addToken(word2);
|
||||
vocabCache.addToken(word3);
|
||||
|
||||
Huffman huffman = new Huffman(vocabCache.vocabWords());
|
||||
huffman.build();
|
||||
huffman.applyIndexes(vocabCache);
|
||||
|
||||
|
||||
BinaryCoOccurrenceWriter<VocabWord> writer = new BinaryCoOccurrenceWriter<>(tempFile);
|
||||
|
||||
CoOccurrenceWeight<VocabWord> object1 = new CoOccurrenceWeight<>();
|
||||
object1.setElement1(word1);
|
||||
object1.setElement2(word2);
|
||||
object1.setWeight(3.14159265);
|
||||
|
||||
writer.writeObject(object1);
|
||||
|
||||
CoOccurrenceWeight<VocabWord> object2 = new CoOccurrenceWeight<>();
|
||||
object2.setElement1(word2);
|
||||
object2.setElement2(word3);
|
||||
object2.setWeight(0.197);
|
||||
|
||||
writer.writeObject(object2);
|
||||
|
||||
CoOccurrenceWeight<VocabWord> object3 = new CoOccurrenceWeight<>();
|
||||
object3.setElement1(word1);
|
||||
object3.setElement2(word3);
|
||||
object3.setWeight(0.001);
|
||||
|
||||
writer.writeObject(object3);
|
||||
|
||||
writer.finish();
|
||||
|
||||
BinaryCoOccurrenceReader<VocabWord> reader = new BinaryCoOccurrenceReader<>(tempFile, vocabCache, null);
|
||||
|
||||
|
||||
CoOccurrenceWeight<VocabWord> r1 = reader.nextObject();
|
||||
log.info("Object received: " + r1);
|
||||
assertNotEquals(null, r1);
|
||||
|
||||
r1 = reader.nextObject();
|
||||
log.info("Object received: " + r1);
|
||||
assertNotEquals(null, r1);
|
||||
|
||||
r1 = reader.nextObject();
|
||||
log.info("Object received: " + r1);
|
||||
assertNotEquals(null, r1);
|
||||
|
||||
}
|
||||
}
|
|
@ -1,90 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.models.glove.count;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
/**
|
||||
* Created by fartovii on 23.12.15.
|
||||
*/
|
||||
public class RoundCountTest extends BaseDL4JTest {
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGet1() throws Exception {
|
||||
RoundCount count = new RoundCount(1);
|
||||
|
||||
assertEquals(0, count.get());
|
||||
|
||||
count.tick();
|
||||
assertEquals(1, count.get());
|
||||
|
||||
count.tick();
|
||||
assertEquals(0, count.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGet2() throws Exception {
|
||||
RoundCount count = new RoundCount(3);
|
||||
|
||||
assertEquals(0, count.get());
|
||||
|
||||
count.tick();
|
||||
assertEquals(1, count.get());
|
||||
|
||||
count.tick();
|
||||
assertEquals(2, count.get());
|
||||
|
||||
count.tick();
|
||||
assertEquals(3, count.get());
|
||||
|
||||
count.tick();
|
||||
assertEquals(0, count.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPrevious1() throws Exception {
|
||||
RoundCount count = new RoundCount(3);
|
||||
|
||||
assertEquals(0, count.get());
|
||||
assertEquals(3, count.previous());
|
||||
|
||||
count.tick();
|
||||
assertEquals(1, count.get());
|
||||
assertEquals(0, count.previous());
|
||||
|
||||
count.tick();
|
||||
assertEquals(2, count.get());
|
||||
assertEquals(1, count.previous());
|
||||
|
||||
count.tick();
|
||||
assertEquals(3, count.get());
|
||||
assertEquals(2, count.previous());
|
||||
|
||||
count.tick();
|
||||
assertEquals(0, count.get());
|
||||
assertEquals(3, count.previous());
|
||||
}
|
||||
}
|
|
@ -21,12 +21,10 @@ import lombok.Getter;
|
|||
import lombok.Setter;
|
||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||
import org.deeplearning4j.models.embeddings.learning.impl.elements.GloVe;
|
||||
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
|
||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||
|
@ -55,6 +53,7 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
|||
import org.junit.Before;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.heartbeat.Heartbeat;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -270,65 +269,6 @@ public class SequenceVectorsTest extends BaseDL4JTest {
|
|||
.epochs(1).resetModel(false).trainElementsRepresentation(false).build();
|
||||
}
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
public void testGlove1() throws Exception {
|
||||
logger.info("Max available memory: " + Runtime.getRuntime().maxMemory());
|
||||
ClassPathResource resource = new ClassPathResource("big/raw_sentences.txt");
|
||||
File file = resource.getFile();
|
||||
|
||||
BasicLineIterator underlyingIterator = new BasicLineIterator(file);
|
||||
|
||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||
|
||||
SentenceTransformer transformer =
|
||||
new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build();
|
||||
|
||||
AbstractSequenceIterator<VocabWord> sequenceIterator =
|
||||
new AbstractSequenceIterator.Builder<>(transformer).build();
|
||||
|
||||
VectorsConfiguration configuration = new VectorsConfiguration();
|
||||
configuration.setWindow(5);
|
||||
configuration.setLearningRate(0.06);
|
||||
configuration.setLayersSize(100);
|
||||
|
||||
|
||||
SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(configuration)
|
||||
.iterate(sequenceIterator).iterations(1).epochs(45)
|
||||
.elementsLearningAlgorithm(new GloVe.Builder<VocabWord>().shuffle(true).symmetric(true)
|
||||
.learningRate(0.05).alpha(0.75).xMax(100.0).build())
|
||||
.resetModel(true).trainElementsRepresentation(true).trainSequencesRepresentation(false).build();
|
||||
|
||||
vectors.fit();
|
||||
|
||||
double sim = vectors.similarity("day", "night");
|
||||
logger.info("Day/night similarity: " + sim);
|
||||
|
||||
|
||||
sim = vectors.similarity("day", "another");
|
||||
logger.info("Day/another similarity: " + sim);
|
||||
|
||||
sim = vectors.similarity("night", "year");
|
||||
logger.info("Night/year similarity: " + sim);
|
||||
|
||||
sim = vectors.similarity("night", "me");
|
||||
logger.info("Night/me similarity: " + sim);
|
||||
|
||||
sim = vectors.similarity("day", "know");
|
||||
logger.info("Day/know similarity: " + sim);
|
||||
|
||||
sim = vectors.similarity("best", "police");
|
||||
logger.info("Best/police similarity: " + sim);
|
||||
|
||||
Collection<String> labels = vectors.wordsNearest("day", 10);
|
||||
logger.info("Nearest labels to 'day': " + labels);
|
||||
|
||||
|
||||
sim = vectors.similarity("day", "night");
|
||||
assertTrue(sim > 0.6d);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testDeepWalk() throws Exception {
|
||||
|
|
|
@ -49,8 +49,8 @@ import java.io.File;
|
|||
import java.util.Collection;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
import static org.awaitility.Awaitility.await;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
|
||||
@Slf4j
|
||||
|
@ -206,12 +206,6 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
|
|||
final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
|
||||
|
||||
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
|
||||
await()
|
||||
.until(new Callable<Boolean>() {
|
||||
@Override
|
||||
public Boolean call() {
|
||||
return net.params().equalsWithEps(restored.params(), 2e-3);
|
||||
}
|
||||
});
|
||||
assertTrue(net.params().equalsWithEps(restored.params(), 2e-3));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer;
|
|||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.deeplearning4j.optimize.Solver;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -247,10 +248,8 @@ public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.c
|
|||
@Override
|
||||
public int[] predict(INDArray input) {
|
||||
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
|
||||
int[] ret = new int[input.rows()];
|
||||
for (int i = 0; i < ret.length; i++)
|
||||
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
|
||||
return ret;
|
||||
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
|
||||
return output.argMax(1).toIntVector();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|||
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.optimize.Solver;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
@ -251,10 +252,8 @@ public class LossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.LossL
|
|||
@Override
|
||||
public int[] predict(INDArray input) {
|
||||
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
|
||||
int[] ret = new int[input.rows()];
|
||||
for (int i = 0; i < ret.length; i++)
|
||||
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
|
||||
return ret;
|
||||
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
|
||||
return output.argMax(1).toIntVector();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -2220,14 +2220,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
if (d.size(0) > Integer.MAX_VALUE)
|
||||
throw new ND4JArraySizeException();
|
||||
|
||||
int[] ret = new int[(int) d.size(0)];
|
||||
if (d.isRowVectorOrScalar())
|
||||
ret[0] = Nd4j.getBlasWrapper().iamax(output);
|
||||
else {
|
||||
for (int i = 0; i < ret.length; i++)
|
||||
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
|
||||
}
|
||||
return ret;
|
||||
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
|
||||
return output.argMax(1).toIntVector();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -1,280 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.spark.models.embeddings.glove;
|
||||
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.apache.spark.SparkConf;
|
||||
import org.apache.spark.api.java.JavaPairRDD;
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.api.java.function.Function;
|
||||
import org.apache.spark.api.java.function.PairFunction;
|
||||
import org.apache.spark.broadcast.Broadcast;
|
||||
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
|
||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCalculator;
|
||||
import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCounts;
|
||||
import org.deeplearning4j.spark.text.functions.TextPipeline;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.legacy.AdaGrad;
|
||||
import org.nd4j.common.primitives.CounterMap;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.nd4j.common.primitives.Triple;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import scala.Tuple2;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
import static org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables.*;
|
||||
|
||||
/**
|
||||
* Spark glove
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class Glove implements Serializable {
|
||||
|
||||
private Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast;
|
||||
private String tokenizerFactoryClazz = DefaultTokenizerFactory.class.getName();
|
||||
private boolean symmetric = true;
|
||||
private int windowSize = 15;
|
||||
private int iterations = 300;
|
||||
private static Logger log = LoggerFactory.getLogger(Glove.class);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param tokenizerFactoryClazz the fully qualified class name of the tokenizer
|
||||
* @param symmetric whether the co occurrence counts should be symmetric
|
||||
* @param windowSize the window size for co occurrence
|
||||
* @param iterations the number of iterations
|
||||
*/
|
||||
public Glove(String tokenizerFactoryClazz, boolean symmetric, int windowSize, int iterations) {
|
||||
this.tokenizerFactoryClazz = tokenizerFactoryClazz;
|
||||
this.symmetric = symmetric;
|
||||
this.windowSize = windowSize;
|
||||
this.iterations = iterations;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param symmetric whether the co occurrence counts should be symmetric
|
||||
* @param windowSize the window size for co occurrence
|
||||
* @param iterations the number of iterations
|
||||
*/
|
||||
public Glove(boolean symmetric, int windowSize, int iterations) {
|
||||
this.symmetric = symmetric;
|
||||
this.windowSize = windowSize;
|
||||
this.iterations = iterations;
|
||||
}
|
||||
|
||||
|
||||
private Pair<INDArray, Float> update(AdaGrad weightAdaGrad, AdaGrad biasAdaGrad, INDArray syn0, INDArray bias,
|
||||
VocabWord w1, INDArray wordVector, INDArray contextVector, double gradient) {
|
||||
//gradient for word vectors
|
||||
INDArray grad1 = contextVector.mul(gradient);
|
||||
INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape());
|
||||
wordVector.subi(update);
|
||||
|
||||
double w1Bias = bias.getDouble(w1.getIndex());
|
||||
double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape());
|
||||
double update2 = w1Bias - biasGradient;
|
||||
bias.putScalar(w1.getIndex(), bias.getDouble(w1.getIndex()) - update2);
|
||||
return new Pair<>(update, (float) update2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Train on the corpus
|
||||
* @param rdd the rdd to train
|
||||
* @return the vocab and weights
|
||||
*/
|
||||
public Pair<VocabCache<VocabWord>, GloveWeightLookupTable> train(JavaRDD<String> rdd) throws Exception {
|
||||
// Each `train()` can use different parameters
|
||||
final JavaSparkContext sc = new JavaSparkContext(rdd.context());
|
||||
final SparkConf conf = sc.getConf();
|
||||
final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class);
|
||||
final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class);
|
||||
final double negative = assignVar(NEGATIVE, conf, Double.class);
|
||||
final int numWords = assignVar(NUM_WORDS, conf, Integer.class);
|
||||
final int window = assignVar(WINDOW, conf, Integer.class);
|
||||
final double alpha = assignVar(ALPHA, conf, Double.class);
|
||||
final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class);
|
||||
final int iterations = assignVar(ITERATIONS, conf, Integer.class);
|
||||
final int nGrams = assignVar(N_GRAMS, conf, Integer.class);
|
||||
final String tokenizer = assignVar(TOKENIZER, conf, String.class);
|
||||
final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class);
|
||||
final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class);
|
||||
|
||||
Map<String, Object> tokenizerVarMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("numWords", numWords);
|
||||
put("nGrams", nGrams);
|
||||
put("tokenizer", tokenizer);
|
||||
put("tokenPreprocessor", tokenPreprocessor);
|
||||
put("removeStop", removeStop);
|
||||
}
|
||||
};
|
||||
Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
|
||||
|
||||
|
||||
TextPipeline pipeline = new TextPipeline(rdd, broadcastTokenizerVarMap);
|
||||
pipeline.buildVocabCache();
|
||||
pipeline.buildVocabWordListRDD();
|
||||
|
||||
|
||||
// Get total word count
|
||||
Long totalWordCount = pipeline.getTotalWordCount();
|
||||
VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
|
||||
JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD = pipeline.getSentenceWordsCountRDD();
|
||||
final Pair<VocabCache<VocabWord>, Long> vocabAndNumWords = new Pair<>(vocabCache, totalWordCount);
|
||||
|
||||
vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());
|
||||
|
||||
final GloveWeightLookupTable gloveWeightLookupTable = new GloveWeightLookupTable.Builder()
|
||||
.cache(vocabAndNumWords.getFirst()).lr(conf.getDouble(GlovePerformer.ALPHA, 0.01))
|
||||
.maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100))
|
||||
.vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300))
|
||||
.xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75)).build();
|
||||
gloveWeightLookupTable.resetWeights();
|
||||
|
||||
gloveWeightLookupTable.getBiasAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().rows());
|
||||
gloveWeightLookupTable.getWeightAdaGrad().historicalGradient =
|
||||
Nd4j.ones(gloveWeightLookupTable.getSyn0().shape());
|
||||
|
||||
|
||||
log.info("Created lookup table of size " + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));
|
||||
CounterMap<String, String> coOccurrenceCounts = sentenceWordsCountRDD
|
||||
.map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize))
|
||||
.fold(new CounterMap<String, String>(), new CoOccurrenceCounts());
|
||||
Iterator<Pair<String, String>> pair2 = coOccurrenceCounts.getIterator();
|
||||
List<Triple<String, String, Float>> counts = new ArrayList<>();
|
||||
|
||||
while (pair2.hasNext()) {
|
||||
Pair<String, String> next = pair2.next();
|
||||
if (coOccurrenceCounts.getCount(next.getFirst(), next.getSecond()) > gloveWeightLookupTable.getMaxCount()) {
|
||||
coOccurrenceCounts.setCount(next.getFirst(), next.getSecond(),
|
||||
(float) gloveWeightLookupTable.getMaxCount());
|
||||
}
|
||||
counts.add(new Triple<>(next.getFirst(), next.getSecond(),
|
||||
(float) coOccurrenceCounts.getCount(next.getFirst(), next.getSecond())));
|
||||
|
||||
}
|
||||
|
||||
log.info("Calculated co occurrences");
|
||||
|
||||
JavaRDD<Triple<String, String, Float>> parallel = sc.parallelize(counts);
|
||||
JavaPairRDD<String, Tuple2<String, Float>> pairs = parallel
|
||||
.mapToPair(new PairFunction<Triple<String, String, Float>, String, Tuple2<String, Float>>() {
|
||||
@Override
|
||||
public Tuple2<String, Tuple2<String, Float>> call(
|
||||
Triple<String, String, Float> stringStringDoubleTriple) throws Exception {
|
||||
return new Tuple2<>(stringStringDoubleTriple.getFirst(),
|
||||
new Tuple2<>(stringStringDoubleTriple.getSecond(),
|
||||
stringStringDoubleTriple.getThird()));
|
||||
}
|
||||
});
|
||||
|
||||
JavaPairRDD<VocabWord, Tuple2<VocabWord, Float>> pairsVocab = pairs.mapToPair(
|
||||
new PairFunction<Tuple2<String, Tuple2<String, Float>>, VocabWord, Tuple2<VocabWord, Float>>() {
|
||||
@Override
|
||||
public Tuple2<VocabWord, Tuple2<VocabWord, Float>> call(
|
||||
Tuple2<String, Tuple2<String, Float>> stringTuple2Tuple2) throws Exception {
|
||||
VocabWord w1 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1());
|
||||
VocabWord w2 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1());
|
||||
return new Tuple2<>(w1, new Tuple2<>(w2, stringTuple2Tuple2._2()._2()));
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
for (int i = 0; i < iterations; i++) {
|
||||
JavaRDD<GloveChange> change =
|
||||
pairsVocab.map(new Function<Tuple2<VocabWord, Tuple2<VocabWord, Float>>, GloveChange>() {
|
||||
@Override
|
||||
public GloveChange call(
|
||||
Tuple2<VocabWord, Tuple2<VocabWord, Float>> vocabWordTuple2Tuple2)
|
||||
throws Exception {
|
||||
VocabWord w1 = vocabWordTuple2Tuple2._1();
|
||||
VocabWord w2 = vocabWordTuple2Tuple2._2()._1();
|
||||
INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
|
||||
INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
|
||||
INDArray bias = gloveWeightLookupTable.getBias();
|
||||
double score = vocabWordTuple2Tuple2._2()._2();
|
||||
double xMax = gloveWeightLookupTable.getxMax();
|
||||
double maxCount = gloveWeightLookupTable.getMaxCount();
|
||||
//w1 * w2 + bias
|
||||
double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
|
||||
prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
|
||||
|
||||
double weight = FastMath.pow(Math.min(1.0, (score / maxCount)), xMax);
|
||||
|
||||
double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
|
||||
if (Double.isNaN(fDiff))
|
||||
fDiff = Nd4j.EPS_THRESHOLD;
|
||||
//amount of change
|
||||
double gradient = fDiff;
|
||||
|
||||
Pair<INDArray, Float> w1Update = update(gloveWeightLookupTable.getWeightAdaGrad(),
|
||||
gloveWeightLookupTable.getBiasAdaGrad(),
|
||||
gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(),
|
||||
w1, w1Vector, w2Vector, gradient);
|
||||
Pair<INDArray, Float> w2Update = update(gloveWeightLookupTable.getWeightAdaGrad(),
|
||||
gloveWeightLookupTable.getBiasAdaGrad(),
|
||||
gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(),
|
||||
w2, w2Vector, w1Vector, gradient);
|
||||
return new GloveChange(w1, w2, w1Update.getFirst(), w2Update.getFirst(),
|
||||
w1Update.getSecond(), w2Update.getSecond(), fDiff,
|
||||
gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient()
|
||||
.slice(w1.getIndex()),
|
||||
gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient()
|
||||
.slice(w2.getIndex()),
|
||||
gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient()
|
||||
.getDouble(w2.getIndex()),
|
||||
gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient()
|
||||
.getDouble(w1.getIndex()));
|
||||
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
|
||||
List<GloveChange> gloveChanges = change.collect();
|
||||
double error = 0.0;
|
||||
for (GloveChange change2 : gloveChanges) {
|
||||
change2.apply(gloveWeightLookupTable);
|
||||
error += change2.getError();
|
||||
}
|
||||
|
||||
|
||||
List l = pairsVocab.collect();
|
||||
Collections.shuffle(l);
|
||||
pairsVocab = sc.parallelizePairs(l);
|
||||
|
||||
log.info("Error at iteration " + i + " was " + error);
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable);
|
||||
}
|
||||
|
||||
}
|
|
@ -1,163 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.spark.models.embeddings.glove;
|
||||
|
||||
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
|
||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class GloveChange implements Serializable {
|
||||
private VocabWord w1, w2;
|
||||
private INDArray w1Update, w2Update;
|
||||
private double w1BiasUpdate, w2BiasUpdate;
|
||||
private double error;
|
||||
private INDArray w1History, w2History;
|
||||
private double w1BiasHistory, w2BiasHistory;
|
||||
|
||||
public GloveChange(VocabWord w1, VocabWord w2, INDArray w1Update, INDArray w2Update, double w1BiasUpdate,
|
||||
double w2BiasUpdate, double error, INDArray w1History, INDArray w2History, double w1BiasHistory,
|
||||
double w2BiasHistory) {
|
||||
this.w1 = w1;
|
||||
this.w2 = w2;
|
||||
this.w1Update = w1Update;
|
||||
this.w2Update = w2Update;
|
||||
this.w1BiasUpdate = w1BiasUpdate;
|
||||
this.w2BiasUpdate = w2BiasUpdate;
|
||||
this.error = error;
|
||||
this.w1History = w1History;
|
||||
this.w2History = w2History;
|
||||
this.w1BiasHistory = w1BiasHistory;
|
||||
this.w2BiasHistory = w2BiasHistory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply the changes to the table
|
||||
* @param table
|
||||
*/
|
||||
public void apply(GloveWeightLookupTable table) {
|
||||
table.getBias().putScalar(w1.getIndex(), table.getBias().getDouble(w1.getIndex()) - w1BiasUpdate);
|
||||
table.getBias().putScalar(w2.getIndex(), table.getBias().getDouble(w2.getIndex()) - w2BiasUpdate);
|
||||
table.getSyn0().slice(w1.getIndex()).subi(w1Update);
|
||||
table.getSyn0().slice(w2.getIndex()).subi(w2Update);
|
||||
table.getWeightAdaGrad().getHistoricalGradient().slice(w1.getIndex()).addi(w1History);
|
||||
table.getWeightAdaGrad().getHistoricalGradient().slice(w2.getIndex()).addi(w2History);
|
||||
table.getBiasAdaGrad().getHistoricalGradient().putScalar(w1.getIndex(),
|
||||
table.getBiasAdaGrad().getHistoricalGradient().getDouble(w1.getIndex()) + w1BiasHistory);
|
||||
table.getBiasAdaGrad().getHistoricalGradient().putScalar(w2.getIndex(),
|
||||
table.getBiasAdaGrad().getHistoricalGradient().getDouble(w2.getIndex()) + w1BiasHistory);
|
||||
|
||||
}
|
||||
|
||||
public INDArray getW1History() {
|
||||
return w1History;
|
||||
}
|
||||
|
||||
public void setW1History(INDArray w1History) {
|
||||
this.w1History = w1History;
|
||||
}
|
||||
|
||||
public INDArray getW2History() {
|
||||
return w2History;
|
||||
}
|
||||
|
||||
public void setW2History(INDArray w2History) {
|
||||
this.w2History = w2History;
|
||||
}
|
||||
|
||||
public double getW1BiasHistory() {
|
||||
return w1BiasHistory;
|
||||
}
|
||||
|
||||
public void setW1BiasHistory(double w1BiasHistory) {
|
||||
this.w1BiasHistory = w1BiasHistory;
|
||||
}
|
||||
|
||||
public double getW2BiasHistory() {
|
||||
return w2BiasHistory;
|
||||
}
|
||||
|
||||
public void setW2BiasHistory(double w2BiasHistory) {
|
||||
this.w2BiasHistory = w2BiasHistory;
|
||||
}
|
||||
|
||||
public VocabWord getW1() {
|
||||
return w1;
|
||||
}
|
||||
|
||||
public void setW1(VocabWord w1) {
|
||||
this.w1 = w1;
|
||||
}
|
||||
|
||||
public VocabWord getW2() {
|
||||
return w2;
|
||||
}
|
||||
|
||||
public void setW2(VocabWord w2) {
|
||||
this.w2 = w2;
|
||||
}
|
||||
|
||||
public INDArray getW1Update() {
|
||||
return w1Update;
|
||||
}
|
||||
|
||||
public void setW1Update(INDArray w1Update) {
|
||||
this.w1Update = w1Update;
|
||||
}
|
||||
|
||||
public INDArray getW2Update() {
|
||||
return w2Update;
|
||||
}
|
||||
|
||||
public void setW2Update(INDArray w2Update) {
|
||||
this.w2Update = w2Update;
|
||||
}
|
||||
|
||||
public double getW1BiasUpdate() {
|
||||
return w1BiasUpdate;
|
||||
}
|
||||
|
||||
public void setW1BiasUpdate(double w1BiasUpdate) {
|
||||
this.w1BiasUpdate = w1BiasUpdate;
|
||||
}
|
||||
|
||||
public double getW2BiasUpdate() {
|
||||
return w2BiasUpdate;
|
||||
}
|
||||
|
||||
public void setW2BiasUpdate(double w2BiasUpdate) {
|
||||
this.w2BiasUpdate = w2BiasUpdate;
|
||||
}
|
||||
|
||||
public double getError() {
|
||||
return error;
|
||||
}
|
||||
|
||||
public void setError(double error) {
|
||||
this.error = error;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return w1.getIndex() + "," + w2.getIndex() + " error " + error;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,171 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.spark.models.embeddings.glove;
|
||||
|
||||
import org.apache.spark.broadcast.Broadcast;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.common.primitives.CounterMap;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class GloveParam implements Serializable {
|
||||
|
||||
private int vectorLength;
|
||||
private boolean useAdaGrad;
|
||||
private double lr;
|
||||
private Random gen;
|
||||
private double negative;
|
||||
private double xMax;
|
||||
private double maxCount;
|
||||
private Broadcast<CounterMap<String, String>> coOccurrenceCounts;
|
||||
|
||||
public GloveParam(int vectorLength, boolean useAdaGrad, double lr, Random gen, double negative, double xMax,
|
||||
double maxCount, Broadcast<CounterMap<String, String>> coOccurrenceCounts) {
|
||||
this.vectorLength = vectorLength;
|
||||
this.useAdaGrad = useAdaGrad;
|
||||
this.lr = lr;
|
||||
this.gen = gen;
|
||||
this.negative = negative;
|
||||
this.xMax = xMax;
|
||||
this.maxCount = maxCount;
|
||||
this.coOccurrenceCounts = coOccurrenceCounts;
|
||||
}
|
||||
|
||||
public int getVectorLength() {
|
||||
return vectorLength;
|
||||
}
|
||||
|
||||
public void setVectorLength(int vectorLength) {
|
||||
this.vectorLength = vectorLength;
|
||||
}
|
||||
|
||||
public boolean isUseAdaGrad() {
|
||||
return useAdaGrad;
|
||||
}
|
||||
|
||||
public void setUseAdaGrad(boolean useAdaGrad) {
|
||||
this.useAdaGrad = useAdaGrad;
|
||||
}
|
||||
|
||||
public double getLr() {
|
||||
return lr;
|
||||
}
|
||||
|
||||
public void setLr(double lr) {
|
||||
this.lr = lr;
|
||||
}
|
||||
|
||||
public Random getGen() {
|
||||
return gen;
|
||||
}
|
||||
|
||||
public void setGen(Random gen) {
|
||||
this.gen = gen;
|
||||
}
|
||||
|
||||
public double getNegative() {
|
||||
return negative;
|
||||
}
|
||||
|
||||
public void setNegative(double negative) {
|
||||
this.negative = negative;
|
||||
}
|
||||
|
||||
public double getxMax() {
|
||||
return xMax;
|
||||
}
|
||||
|
||||
public void setxMax(double xMax) {
|
||||
this.xMax = xMax;
|
||||
}
|
||||
|
||||
public double getMaxCount() {
|
||||
return maxCount;
|
||||
}
|
||||
|
||||
public void setMaxCount(double maxCount) {
|
||||
this.maxCount = maxCount;
|
||||
}
|
||||
|
||||
public Broadcast<CounterMap<String, String>> getCoOccurrenceCounts() {
|
||||
return coOccurrenceCounts;
|
||||
}
|
||||
|
||||
public void setCoOccurrenceCounts(Broadcast<CounterMap<String, String>> coOccurrenceCounts) {
|
||||
this.coOccurrenceCounts = coOccurrenceCounts;
|
||||
}
|
||||
|
||||
|
||||
public static class Builder {
|
||||
private int vectorLength = 300;
|
||||
private boolean useAdaGrad = true;
|
||||
private double lr = 0.025;
|
||||
private Random gen;
|
||||
private double negative = 5;
|
||||
private double xMax = 0.75;
|
||||
private double maxCount = 100;
|
||||
private Broadcast<CounterMap<String, String>> coOccurrenceCounts;
|
||||
|
||||
public Builder vectorLength(int vectorLength) {
|
||||
this.vectorLength = vectorLength;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder useAdaGrad(boolean useAdaGrad) {
|
||||
this.useAdaGrad = useAdaGrad;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder lr(double lr) {
|
||||
this.lr = lr;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder gen(Random gen) {
|
||||
this.gen = gen;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder negative(double negative) {
|
||||
this.negative = negative;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder xMax(double xMax) {
|
||||
this.xMax = xMax;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder maxCount(double maxCount) {
|
||||
this.maxCount = maxCount;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder coOccurrenceCounts(Broadcast<CounterMap<String, String>> coOccurrenceCounts) {
|
||||
this.coOccurrenceCounts = coOccurrenceCounts;
|
||||
return this;
|
||||
}
|
||||
|
||||
public GloveParam build() {
|
||||
return new GloveParam(vectorLength, useAdaGrad, lr, gen, negative, xMax, maxCount, coOccurrenceCounts);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -1,48 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.spark.models.embeddings.glove;
|
||||
|
||||
import org.apache.spark.api.java.function.Function;
|
||||
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
|
||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||
import org.nd4j.common.primitives.Triple;
|
||||
|
||||
|
||||
/**
|
||||
* Base line glove performer
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class GlovePerformer implements Function<Triple<VocabWord, VocabWord, Double>, GloveChange> {
|
||||
|
||||
|
||||
public final static String NAME_SPACE = "org.deeplearning4j.scaleout.perform.models.glove";
|
||||
public final static String VECTOR_LENGTH = NAME_SPACE + ".length";
|
||||
public final static String ALPHA = NAME_SPACE + ".alpha";
|
||||
public final static String X_MAX = NAME_SPACE + ".xmax";
|
||||
public final static String MAX_COUNT = NAME_SPACE + ".maxcount";
|
||||
private GloveWeightLookupTable table;
|
||||
|
||||
public GlovePerformer(GloveWeightLookupTable table) {
|
||||
this.table = table;
|
||||
}
|
||||
|
||||
@Override
|
||||
public GloveChange call(Triple<VocabWord, VocabWord, Double> pair) throws Exception {
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.spark.models.embeddings.glove;
|
||||
|
||||
import org.apache.spark.api.java.function.Function;
|
||||
import org.apache.spark.broadcast.Broadcast;
|
||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.nd4j.common.primitives.Triple;
|
||||
|
||||
/**
|
||||
* Convert string to vocab words
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class VocabWordPairs implements Function<Triple<String, String, Double>, Triple<VocabWord, VocabWord, Double>> {
|
||||
private Broadcast<VocabCache<VocabWord>> vocab;
|
||||
|
||||
public VocabWordPairs(Broadcast<VocabCache<VocabWord>> vocab) {
|
||||
this.vocab = vocab;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Triple<VocabWord, VocabWord, Double> call(Triple<String, String, Double> v1) throws Exception {
|
||||
return new Triple<>((VocabWord) vocab.getValue().wordFor(v1.getFirst()),
|
||||
(VocabWord) vocab.getValue().wordFor(v1.getSecond()), v1.getThird());
|
||||
}
|
||||
}
|
|
@ -1,91 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.spark.models.embeddings.glove.cooccurrences;
|
||||
|
||||
import org.apache.spark.api.java.function.Function;
|
||||
import org.apache.spark.broadcast.Broadcast;
|
||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.primitives.CounterMap;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
/**
|
||||
* Calculate co occurrences based on tokens
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class CoOccurrenceCalculator implements Function<Pair<List<String>, AtomicLong>, CounterMap<String, String>> {
|
||||
private boolean symmetric = false;
|
||||
private Broadcast<VocabCache<VocabWord>> vocab;
|
||||
private int windowSize = 5;
|
||||
|
||||
public CoOccurrenceCalculator(boolean symmetric, Broadcast<VocabCache<VocabWord>> vocab, int windowSize) {
|
||||
this.symmetric = symmetric;
|
||||
this.vocab = vocab;
|
||||
this.windowSize = windowSize;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public CounterMap<String, String> call(Pair<List<String>, AtomicLong> pair) throws Exception {
|
||||
List<String> sentence = pair.getFirst();
|
||||
CounterMap<String, String> coOCurreneCounts = new CounterMap<>();
|
||||
VocabCache vocab = this.vocab.value();
|
||||
for (int i = 0; i < sentence.size(); i++) {
|
||||
int wordIdx = vocab.indexOf(sentence.get(i));
|
||||
String w1 = ((VocabWord) vocab.wordFor(sentence.get(i))).getWord();
|
||||
|
||||
if (wordIdx < 0) // || w1.equals(Glove.UNK))
|
||||
continue;
|
||||
int windowStop = Math.min(i + windowSize + 1, sentence.size());
|
||||
for (int j = i; j < windowStop; j++) {
|
||||
int otherWord = vocab.indexOf(sentence.get(j));
|
||||
String w2 = ((VocabWord) vocab.wordFor(sentence.get(j))).getWord();
|
||||
if (vocab.indexOf(sentence.get(j)) < 0) // || w2.equals(Glove.UNK))
|
||||
continue;
|
||||
|
||||
if (otherWord == wordIdx)
|
||||
continue;
|
||||
if (wordIdx < otherWord) {
|
||||
coOCurreneCounts.incrementCount(sentence.get(i), sentence.get(j),
|
||||
(float) (1.0 / (j - i + Nd4j.EPS_THRESHOLD)));
|
||||
if (symmetric)
|
||||
coOCurreneCounts.incrementCount(sentence.get(j), sentence.get(i),
|
||||
(float) (1.0 / (j - i + Nd4j.EPS_THRESHOLD)));
|
||||
|
||||
|
||||
|
||||
} else {
|
||||
float coCount = (float) (1.0 / (j - i + Nd4j.EPS_THRESHOLD));
|
||||
coOCurreneCounts.incrementCount(sentence.get(j), sentence.get(i), (float) coCount);
|
||||
if (symmetric)
|
||||
coOCurreneCounts.incrementCount(sentence.get(i), sentence.get(j),
|
||||
(float) (1.0 / (j - i + Nd4j.EPS_THRESHOLD)));
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
return coOCurreneCounts;
|
||||
}
|
||||
}
|
|
@ -1,37 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.spark.models.embeddings.glove.cooccurrences;
|
||||
|
||||
import org.apache.spark.api.java.function.Function2;
|
||||
import org.nd4j.common.primitives.CounterMap;
|
||||
|
||||
|
||||
/**
|
||||
* Co occurrence count reduction
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class CoOccurrenceCounts implements
|
||||
Function2<CounterMap<String, String>, CounterMap<String, String>, CounterMap<String, String>> {
|
||||
|
||||
|
||||
@Override
|
||||
public CounterMap<String, String> call(CounterMap<String, String> v1, CounterMap<String, String> v2)
|
||||
throws Exception {
|
||||
v1.incrementAll(v2);
|
||||
return v1;
|
||||
}
|
||||
}
|
|
@ -1,62 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.spark.models.embeddings.glove;
|
||||
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.api.java.function.Function;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
||||
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
|
||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.deeplearning4j.spark.text.BaseSparkTest;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
|
||||
import java.util.Collection;
|
||||
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
/**
|
||||
* Created by agibsonccc on 1/31/15.
|
||||
*/
|
||||
@Ignore
|
||||
public class GloveTest extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
public void testGlove() throws Exception {
|
||||
Glove glove = new Glove(true, 5, 100);
|
||||
JavaRDD<String> corpus = sc.textFile(new ClassPathResource("big/raw_sentences.txt").getFile().getAbsolutePath())
|
||||
.map(new Function<String, String>() {
|
||||
@Override
|
||||
public String call(String s) throws Exception {
|
||||
return s.toLowerCase();
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
Pair<VocabCache<VocabWord>, GloveWeightLookupTable> table = glove.train(corpus);
|
||||
WordVectors vectors = WordVectorSerializer
|
||||
.fromPair(new Pair<>((InMemoryLookupTable) table.getSecond(), (VocabCache) table.getFirst()));
|
||||
Collection<String> words = vectors.wordsNearest("day", 20);
|
||||
assertTrue(words.contains("week"));
|
||||
}
|
||||
|
||||
}
|
|
@ -26,6 +26,7 @@ import org.deeplearning4j.nn.transferlearning.TransferLearning;
|
|||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.ui.model.stats.StatsListener;
|
||||
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
|
||||
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
@ -42,8 +43,10 @@ import java.io.IOException;
|
|||
*/
|
||||
public class TestTransferStatsCollection extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90_000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test() throws IOException {
|
||||
|
@ -62,9 +65,7 @@ public class TestTransferStatsCollection extends BaseDL4JTest {
|
|||
new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
|
||||
.setFeatureExtractor(0).build();
|
||||
|
||||
File dir = testDir.newFolder();
|
||||
File f = new File(dir, "dl4jTestTransferStatsCollection.bin");
|
||||
net2.setListeners(new StatsListener(new FileStatsStorage(f)));
|
||||
net2.setListeners(new StatsListener(new InMemoryStatsStorage()));
|
||||
|
||||
//Previosuly: failed on frozen layers
|
||||
net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));
|
||||
|
|
|
@ -88,7 +88,7 @@ namespace sd {
|
|||
cudaFree(_allocationPointer);
|
||||
|
||||
if (_scalarPointer != nullptr)
|
||||
cudaFree(_scalarPointer);
|
||||
cudaFreeHost(_scalarPointer);
|
||||
|
||||
if (_allocationPointer != nullptr)
|
||||
cudaFree(_reductionPointer);
|
||||
|
|
|
@ -243,9 +243,6 @@ __host__ void ReduceBoolFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStrea
|
|||
int *dimension, int dimensionLength,
|
||||
void *reductionPointer,
|
||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
|
||||
|
||||
nd4j_printf("Step A%i\n", -1);
|
||||
|
||||
if(shape::isEmpty(hXShapeInfo)) {
|
||||
|
||||
if(shape::isEmpty(hZShapeInfo))
|
||||
|
|
|
@ -515,8 +515,8 @@ BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT cudaDecodeBitmapGeneric, (dim3 &
|
|||
|
||||
template <bool storeSum, bool isNP2>
|
||||
__host__ void prescanLauncher(dim3 &blocks, dim3 &threads, int shmem, cudaStream_t *stream, int *g_odata, const int *g_idata, int *g_blockSums, int n, int blockIndex, int baseIndex) {
|
||||
//printf("Prescan grid: <%i/%i/%i>; threads: <%i/%i/%i>; shareMemSize: %i\n", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, shmem);
|
||||
prescan<storeSum, isNP2><<<blocks, threads, shmem, *stream>>>(g_odata, g_idata, g_blockSums, n, blockIndex, baseIndex);
|
||||
sd::DebugHelper::checkErrorCode(stream, "prescan(...) failed");
|
||||
};
|
||||
|
||||
template <typename S, typename T>
|
||||
|
|
|
@ -41,8 +41,12 @@ namespace sd {
|
|||
else
|
||||
numThreads = sd::floorPow2(numElements);
|
||||
|
||||
numThreads = sd::math::nd4j_max<int>(1, numThreads);
|
||||
|
||||
int numEltsPerBlock = numThreads * 2;
|
||||
|
||||
|
||||
|
||||
// if this is a non-power-of-2 array, the last block will be non-full
|
||||
// compute the smallest power of 2 able to compute its scan.
|
||||
int numEltsLastBlock =
|
||||
|
@ -102,8 +106,6 @@ namespace sd {
|
|||
} else {
|
||||
sd::prescanLauncher<false, true>(grid, threads, sharedMemSize, stream, dZ, dX, 0, numElements, 0, 0);
|
||||
}
|
||||
|
||||
sd::DebugHelper::checkErrorCode(stream, "prescanArray(...) failed");
|
||||
}
|
||||
|
||||
static void encodeThresholdP2Int_(void **prs, int *dx, Nd4jLong N, int *dz) {
|
||||
|
|
|
@ -119,7 +119,7 @@ TEST_F(CudaBasicsTests1, TestPairwise_1) {
|
|||
z.tickWriteHost();
|
||||
|
||||
for (int e = 0; e < z.lengthOf(); e++) {
|
||||
nd4j_printf("step %i\n", e);
|
||||
//nd4j_printf("step %i\n", e);
|
||||
ASSERT_NEAR(exp.e<double>(e), z.e<double>(e), 1e-5);
|
||||
}
|
||||
}
|
||||
|
@ -2822,7 +2822,7 @@ TEST_F(CudaBasicsTests1, execSummaryStats_2) {
|
|||
// delete cuda stream
|
||||
cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult);
|
||||
}
|
||||
|
||||
/*
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(CudaBasicsTests1, execSummaryStats_3) {
|
||||
|
||||
|
@ -2876,6 +2876,7 @@ TEST_F(CudaBasicsTests1, execSummaryStats_3) {
|
|||
// delete cuda stream
|
||||
cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult);
|
||||
}
|
||||
*/
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(CudaBasicsTests1, execSummaryStatsScalar_1) {
|
||||
|
|
|
@ -1054,6 +1054,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
|
|||
ASSERT_TRUE(testData.equalsTo(result));
|
||||
}
|
||||
|
||||
/*
|
||||
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) {
|
||||
|
||||
NDArray input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
|
||||
|
@ -1114,6 +1115,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) {
|
|||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) {
|
||||
|
||||
NDArray input = NDArrayFactory::create<float>('c', {1, 3, 3, 1});
|
||||
|
@ -1530,6 +1532,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) {
|
|||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
}
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
|
||||
|
|
|
@ -826,6 +826,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
case FLOAT:
|
||||
return ((FloatIndexer) indexer).get(i);
|
||||
case UINT32:
|
||||
return ((UIntIndexer) indexer).get(i);
|
||||
case INT:
|
||||
return ((IntIndexer) indexer).get(i);
|
||||
case BFLOAT16:
|
||||
|
@ -866,10 +867,11 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
return (long) ((Bfloat16Indexer) indexer).get(i);
|
||||
case HALF:
|
||||
return (long) ((HalfIndexer) indexer).get( i);
|
||||
case UINT64:
|
||||
case UINT64: //Fall through
|
||||
case LONG:
|
||||
return ((LongIndexer) indexer).get(i);
|
||||
case UINT32:
|
||||
return (long) ((UIntIndexer) indexer).get(i);
|
||||
case INT:
|
||||
return (long) ((IntIndexer) indexer).get(i);
|
||||
case UINT16:
|
||||
|
@ -906,6 +908,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
case BOOL:
|
||||
return (short) (((BooleanIndexer) indexer).get(i) ? 1 : 0);
|
||||
case UINT32:
|
||||
return (short) ((UIntIndexer)indexer).get(i);
|
||||
case INT:
|
||||
return (short) ((IntIndexer) indexer).get(i);
|
||||
case UINT16:
|
||||
|
@ -943,6 +946,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
case BOOL:
|
||||
return ((BooleanIndexer) indexer).get(i) ? 1.f : 0.f;
|
||||
case UINT32:
|
||||
return (float) ((UIntIndexer)indexer).get(i);
|
||||
case INT:
|
||||
return (float) ((IntIndexer) indexer).get(i);
|
||||
case UINT16:
|
||||
|
@ -957,7 +961,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
return (float) ((UByteIndexer) indexer).get(i);
|
||||
case BYTE:
|
||||
return (float) ((ByteIndexer) indexer).get(i);
|
||||
case UINT64:
|
||||
case UINT64: //Fall through
|
||||
case LONG:
|
||||
return (float) ((LongIndexer) indexer).get(i);
|
||||
case FLOAT:
|
||||
|
@ -978,6 +982,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
case BOOL:
|
||||
return ((BooleanIndexer) indexer).get(i) ? 1 : 0;
|
||||
case UINT32:
|
||||
return (int)((UIntIndexer) indexer).get(i);
|
||||
case INT:
|
||||
return ((IntIndexer) indexer).get(i);
|
||||
case BFLOAT16:
|
||||
|
@ -992,7 +997,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
return ((UByteIndexer) indexer).get(i);
|
||||
case BYTE:
|
||||
return ((ByteIndexer) indexer).get(i);
|
||||
case UINT64:
|
||||
case UINT64: //Fall through
|
||||
case LONG:
|
||||
return (int) ((LongIndexer) indexer).get(i);
|
||||
case FLOAT:
|
||||
|
@ -1058,6 +1063,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
((ShortIndexer) indexer).put(i, (short) element);
|
||||
break;
|
||||
case UINT32:
|
||||
((UIntIndexer) indexer).put(i, (long)element);
|
||||
break;
|
||||
case INT:
|
||||
((IntIndexer) indexer).put(i, (int) element);
|
||||
break;
|
||||
|
@ -1104,6 +1111,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
((ShortIndexer) indexer).put(i, (short) element);
|
||||
break;
|
||||
case UINT32:
|
||||
((UIntIndexer) indexer).put(i, (long)element);
|
||||
break;
|
||||
case INT:
|
||||
((IntIndexer) indexer).put(i, (int) element);
|
||||
break;
|
||||
|
@ -1150,10 +1159,12 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
((ShortIndexer) indexer).put(i, (short) element);
|
||||
break;
|
||||
case UINT32:
|
||||
((UIntIndexer) indexer).put(i, element);
|
||||
break;
|
||||
case INT:
|
||||
((IntIndexer) indexer).put(i, element);
|
||||
break;
|
||||
case UINT64:
|
||||
case UINT64: //Fall through
|
||||
case LONG:
|
||||
((LongIndexer) indexer).put(i, element);
|
||||
break;
|
||||
|
@ -1195,8 +1206,10 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
case SHORT:
|
||||
((ShortIndexer) indexer).put(i, element ? (short) 1 : (short) 0);
|
||||
break;
|
||||
case INT:
|
||||
case UINT32:
|
||||
((UIntIndexer) indexer).put(i, element ? 1 : 0);
|
||||
break;
|
||||
case INT:
|
||||
((IntIndexer) indexer).put(i, element ? 1 : 0);
|
||||
break;
|
||||
case UINT64:
|
||||
|
@ -1242,6 +1255,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
((ShortIndexer) indexer).put(i, (short) element);
|
||||
break;
|
||||
case UINT32:
|
||||
((UIntIndexer) indexer).put(i, element);
|
||||
break;
|
||||
case INT:
|
||||
((IntIndexer) indexer).put(i, (int) element);
|
||||
break;
|
||||
|
|
|
@ -324,6 +324,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
indexer = FloatIndexer.create((FloatPointer) pointer);
|
||||
break;
|
||||
case UINT32:
|
||||
this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer();
|
||||
indexer = UIntIndexer.create((IntPointer) pointer);
|
||||
break;
|
||||
case INT:
|
||||
this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer();
|
||||
indexer = IntIndexer.create((IntPointer) pointer);
|
||||
|
@ -336,7 +339,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer();
|
||||
indexer = HalfIndexer.create((ShortPointer) pointer);
|
||||
break;
|
||||
case UINT64:
|
||||
case UINT64: //Fall through
|
||||
case LONG:
|
||||
this.pointer = new CudaPointer(hostPointer, length, 0).asLongPointer();
|
||||
indexer = LongIndexer.create((LongPointer) pointer);
|
||||
|
@ -501,6 +504,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
indexer = FloatIndexer.create((FloatPointer) pointer);
|
||||
break;
|
||||
case UINT32:
|
||||
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer();
|
||||
indexer = UIntIndexer.create((IntPointer) pointer);
|
||||
break;
|
||||
case INT:
|
||||
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer();
|
||||
indexer = IntIndexer.create((IntPointer) pointer);
|
||||
|
@ -513,7 +519,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer();
|
||||
indexer = HalfIndexer.create((ShortPointer) pointer);
|
||||
break;
|
||||
case UINT64:
|
||||
case UINT64: //Fall through
|
||||
case LONG:
|
||||
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asLongPointer();
|
||||
indexer = LongIndexer.create((LongPointer) pointer);
|
||||
|
|
|
@ -121,6 +121,24 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
|||
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();
|
||||
|
||||
setIndexer(ByteIndexer.create((BytePointer) pointer));
|
||||
} else if(dataType() == DataType.FLOAT16){
|
||||
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
|
||||
setIndexer(HalfIndexer.create((ShortPointer) pointer));
|
||||
} else if(dataType() == DataType.BFLOAT16){
|
||||
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
|
||||
setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
|
||||
} else if(dataType() == DataType.BOOL){
|
||||
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBoolPointer();
|
||||
setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
|
||||
} else if(dataType() == DataType.UINT16){
|
||||
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
|
||||
setIndexer(UShortIndexer.create((ShortPointer) pointer));
|
||||
} else if(dataType() == DataType.UINT32){
|
||||
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
|
||||
setIndexer(UIntIndexer.create((IntPointer) pointer));
|
||||
} else if (dataType() == DataType.UINT64) {
|
||||
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
|
||||
setIndexer(LongIndexer.create((LongPointer) pointer));
|
||||
}
|
||||
|
||||
Nd4j.getDeallocatorService().pickObject(this);
|
||||
|
@ -336,15 +354,13 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
|||
} else if (dataType() == DataType.UINT32) {
|
||||
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
|
||||
|
||||
// FIXME: we need unsigned indexer here
|
||||
setIndexer(IntIndexer.create((IntPointer) pointer));
|
||||
setIndexer(UIntIndexer.create((IntPointer) pointer));
|
||||
|
||||
if (initialize)
|
||||
fillPointerWithZero();
|
||||
} else if (dataType() == DataType.UINT64) {
|
||||
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
|
||||
|
||||
// FIXME: we need unsigned indexer here
|
||||
setIndexer(LongIndexer.create((LongPointer) pointer));
|
||||
|
||||
if (initialize)
|
||||
|
@ -500,7 +516,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
|||
|
||||
// FIXME: need unsigned indexer here
|
||||
pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length());
|
||||
setIndexer(IntIndexer.create((IntPointer) pointer));
|
||||
setIndexer(UIntIndexer.create((IntPointer) pointer));
|
||||
|
||||
} else if (dataType() == DataType.UINT64) {
|
||||
attached = true;
|
||||
|
|
|
@ -8395,6 +8395,25 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
assertEquals(e, z);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateBufferFromByteBuffer(){
|
||||
|
||||
for(DataType dt : DataType.values()){
|
||||
if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
|
||||
continue;
|
||||
// System.out.println(dt);
|
||||
|
||||
int lengthBytes = 256;
|
||||
int lengthElements = lengthBytes / dt.width();
|
||||
ByteBuffer bb = ByteBuffer.allocateDirect(lengthBytes);
|
||||
|
||||
DataBuffer db = Nd4j.createBuffer(bb, dt, lengthElements, 0);
|
||||
INDArray arr = Nd4j.create(db, new long[]{lengthElements});
|
||||
|
||||
arr.toStringFull();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
Loading…
Reference in New Issue