Merge remote-tracking branch 'origin/master'
commit
6ce620709a
|
@ -24,7 +24,6 @@ import org.apache.commons.io.FileUtils;
|
|||
import org.apache.commons.io.IOUtils;
|
||||
import org.apache.commons.io.LineIterator;
|
||||
import org.apache.commons.io.output.CloseShieldOutputStream;
|
||||
import org.deeplearning4j.exception.DL4JException;
|
||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||
|
@ -52,7 +51,6 @@ import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
|||
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
import org.deeplearning4j.util.DL4JFileUtils;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.compression.impl.NoOp;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
|
@ -68,8 +66,6 @@ import org.nd4j.util.OneTimeLogger;
|
|||
|
||||
import java.io.*;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
@ -78,6 +74,80 @@ import java.util.zip.*;
|
|||
/**
|
||||
* This is utility class, providing various methods for WordVectors serialization
|
||||
*
|
||||
* List of available serialization methods (please keep this list consistent with source code):
|
||||
*
|
||||
* <ul>
|
||||
* <li>Serializers for Word2Vec:</li>
|
||||
* {@link #writeWordVectors(WeightLookupTable, File)}
|
||||
* {@link #writeWordVectors(WeightLookupTable, OutputStream)}
|
||||
* {@link #writeWord2VecModel(Word2Vec, File)}
|
||||
* {@link #writeWord2VecModel(Word2Vec, String)}
|
||||
* {@link #writeWord2VecModel(Word2Vec, OutputStream)}
|
||||
*
|
||||
* <li>Deserializers for Word2Vec:</li>
|
||||
* {@link #readWord2VecModel(File)}
|
||||
* {@link #readWord2VecModel(String)}
|
||||
* {@link #readWord2VecModel(File, boolean)}
|
||||
* {@link #readWord2VecModel(String, boolean)}
|
||||
* {@link #readAsBinaryNoLineBreaks(File)}
|
||||
* {@link #readAsBinary(File)}
|
||||
* {@link #readAsCsv(File)}
|
||||
* {@link #readBinaryModel(File, boolean, boolean)}
|
||||
* {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)}
|
||||
* {@link #readWord2Vec(String, boolean)}
|
||||
* {@link #readWord2Vec(File, boolean)}
|
||||
* {@link #readWord2Vec(InputStream, boolean)}
|
||||
*
|
||||
* <li>Serializers for ParaVec:</li>
|
||||
* {@link #writeParagraphVectors(ParagraphVectors, File)}
|
||||
* {@link #writeParagraphVectors(ParagraphVectors, String)}
|
||||
* {@link #writeParagraphVectors(ParagraphVectors, OutputStream)}
|
||||
*
|
||||
* <li>Deserializers for ParaVec:</li>
|
||||
* {@link #readParagraphVectors(File)}
|
||||
* {@link #readParagraphVectors(String)}
|
||||
* {@link #readParagraphVectors(InputStream)}
|
||||
*
|
||||
* <li>Serializers for GloVe:</li>
|
||||
* {@link #writeWordVectors(Glove, File)}
|
||||
* {@link #writeWordVectors(Glove, String)}
|
||||
* {@link #writeWordVectors(Glove, OutputStream)}
|
||||
*
|
||||
* <li>Adapters</li>
|
||||
* {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
|
||||
* {@link #fromPair(Pair)}
|
||||
* {@link #loadTxt(File)}
|
||||
*
|
||||
* <li>Serializers to tSNE format</li>
|
||||
* {@link #writeTsneFormat(Glove, INDArray, File)}
|
||||
* {@link #writeTsneFormat(Word2Vec, INDArray, File)}
|
||||
*
|
||||
* <li>FastText serializer:</li>
|
||||
* {@link #writeWordVectors(FastText, File)}
|
||||
*
|
||||
* <li>FastText deserializer:</li>
|
||||
* {@link #readWordVectors(File)}
|
||||
*
|
||||
* <li>SequenceVectors serializers:</li>
|
||||
* {@link #writeSequenceVectors(SequenceVectors, OutputStream)}
|
||||
* {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, File)}
|
||||
* {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, String)}
|
||||
* {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, OutputStream)}
|
||||
* {@link #writeLookupTable(WeightLookupTable, File)}
|
||||
* {@link #writeVocabCache(VocabCache, File)}
|
||||
* {@link #writeVocabCache(VocabCache, OutputStream)}
|
||||
*
|
||||
* <li>SequenceVectors deserializers:</li>
|
||||
* {@link #readSequenceVectors(File, boolean)}
|
||||
* {@link #readSequenceVectors(String, boolean)}
|
||||
* {@link #readSequenceVectors(SequenceElementFactory, File)}
|
||||
* {@link #readSequenceVectors(InputStream, boolean)}
|
||||
* {@link #readSequenceVectors(SequenceElementFactory, InputStream)}
|
||||
* {@link #readLookupTable(File)}
|
||||
* {@link #readLookupTable(InputStream)}
|
||||
*
|
||||
* </ul>
|
||||
*
|
||||
* @author Adam Gibson
|
||||
* @author raver119
|
||||
* @author alexander@skymind.io
|
||||
|
@ -97,7 +167,7 @@ public class WordVectorSerializer {
|
|||
* @throws IOException
|
||||
* @throws NumberFormatException
|
||||
*/
|
||||
private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
|
||||
/*private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
|
||||
InMemoryLookupTable lookupTable;
|
||||
VocabCache cache;
|
||||
INDArray syn0;
|
||||
|
@ -142,7 +212,7 @@ public class WordVectorSerializer {
|
|||
ret.setLookupTable(lookupTable);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}*/
|
||||
|
||||
/**
|
||||
* Read a binary word2vec file.
|
||||
|
@ -173,8 +243,8 @@ public class WordVectorSerializer {
|
|||
try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName())
|
||||
? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile));
|
||||
DataInputStream dis = new DataInputStream(bis)) {
|
||||
words = Integer.parseInt(readString(dis));
|
||||
size = Integer.parseInt(readString(dis));
|
||||
words = Integer.parseInt(ReadHelper.readString(dis));
|
||||
size = Integer.parseInt(ReadHelper.readString(dis));
|
||||
syn0 = Nd4j.create(words, size);
|
||||
cache = new AbstractCache<>();
|
||||
|
||||
|
@ -188,11 +258,11 @@ public class WordVectorSerializer {
|
|||
float[] vector = new float[size];
|
||||
for (int i = 0; i < words; i++) {
|
||||
|
||||
word = readString(dis);
|
||||
word = ReadHelper.readString(dis);
|
||||
log.trace("Loading " + word + " with word " + i);
|
||||
|
||||
for (int j = 0; j < size; j++) {
|
||||
vector[j] = readFloat(dis);
|
||||
vector[j] = ReadHelper.readFloat(dis);
|
||||
}
|
||||
|
||||
if (cache.containsWord(word))
|
||||
|
@ -236,64 +306,6 @@ public class WordVectorSerializer {
|
|||
|
||||
}
|
||||
|
||||
/**
|
||||
* Read a float from a data input stream Credit to:
|
||||
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
|
||||
*
|
||||
* @param is
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
public static float readFloat(InputStream is) throws IOException {
|
||||
byte[] bytes = new byte[4];
|
||||
is.read(bytes);
|
||||
return getFloat(bytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Read a string from a data input stream Credit to:
|
||||
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
|
||||
*
|
||||
* @param b
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
public static float getFloat(byte[] b) {
|
||||
int accum = 0;
|
||||
accum = accum | (b[0] & 0xff) << 0;
|
||||
accum = accum | (b[1] & 0xff) << 8;
|
||||
accum = accum | (b[2] & 0xff) << 16;
|
||||
accum = accum | (b[3] & 0xff) << 24;
|
||||
return Float.intBitsToFloat(accum);
|
||||
}
|
||||
|
||||
/**
|
||||
* Read a string from a data input stream Credit to:
|
||||
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
|
||||
*
|
||||
* @param dis
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
public static String readString(DataInputStream dis) throws IOException {
|
||||
byte[] bytes = new byte[MAX_SIZE];
|
||||
byte b = dis.readByte();
|
||||
int i = -1;
|
||||
StringBuilder sb = new StringBuilder();
|
||||
while (b != 32 && b != 10) {
|
||||
i++;
|
||||
bytes[i] = b;
|
||||
b = dis.readByte();
|
||||
if (i == 49) {
|
||||
sb.append(new String(bytes, "UTF-8"));
|
||||
i = -1;
|
||||
bytes = new byte[MAX_SIZE];
|
||||
}
|
||||
}
|
||||
sb.append(new String(bytes, 0, i + 1, "UTF-8"));
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* This method writes word vectors to the given path.
|
||||
* Please note: this method doesn't load whole vocab/lookupTable into memory, so it's able to process large vocabularies served over network.
|
||||
|
@ -355,7 +367,7 @@ public class WordVectorSerializer {
|
|||
val builder = new StringBuilder();
|
||||
|
||||
val l = element.getLabel();
|
||||
builder.append(encodeB64(l)).append(" ");
|
||||
builder.append(ReadHelper.encodeB64(l)).append(" ");
|
||||
val vec = lookupTable.vector(element.getLabel());
|
||||
for (int i = 0; i < vec.length(); i++) {
|
||||
builder.append(vec.getDouble(i));
|
||||
|
@ -518,7 +530,7 @@ public class WordVectorSerializer {
|
|||
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
|
||||
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
|
||||
VocabWord word = vectors.getVocab().elementAtIndex(i);
|
||||
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
|
||||
StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
|
||||
for (int code : word.getCodes()) {
|
||||
builder.append(code).append(" ");
|
||||
}
|
||||
|
@ -536,7 +548,7 @@ public class WordVectorSerializer {
|
|||
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
|
||||
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
|
||||
VocabWord word = vectors.getVocab().elementAtIndex(i);
|
||||
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
|
||||
StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
|
||||
for (int point : word.getPoints()) {
|
||||
builder.append(point).append(" ");
|
||||
}
|
||||
|
@ -554,7 +566,7 @@ public class WordVectorSerializer {
|
|||
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
|
||||
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
|
||||
VocabWord word = vectors.getVocab().elementAtIndex(i);
|
||||
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ")
|
||||
StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ")
|
||||
.append(word.getElementFrequency()).append(" ")
|
||||
.append(vectors.getVocab().docAppearedIn(word.getLabel()));
|
||||
|
||||
|
@ -638,7 +650,7 @@ public class WordVectorSerializer {
|
|||
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
|
||||
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
|
||||
VocabWord word = vectors.getVocab().elementAtIndex(i);
|
||||
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
|
||||
StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
|
||||
for (int code : word.getCodes()) {
|
||||
builder.append(code).append(" ");
|
||||
}
|
||||
|
@ -656,7 +668,7 @@ public class WordVectorSerializer {
|
|||
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
|
||||
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
|
||||
VocabWord word = vectors.getVocab().elementAtIndex(i);
|
||||
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
|
||||
StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
|
||||
for (int point : word.getPoints()) {
|
||||
builder.append(point).append(" ");
|
||||
}
|
||||
|
@ -677,7 +689,7 @@ public class WordVectorSerializer {
|
|||
StringBuilder builder = new StringBuilder();
|
||||
for (VocabWord word : vectors.getVocab().tokens()) {
|
||||
if (word.isLabel())
|
||||
builder.append(encodeB64(word.getLabel())).append("\n");
|
||||
builder.append(ReadHelper.encodeB64(word.getLabel())).append("\n");
|
||||
}
|
||||
IOUtils.write(builder.toString().trim(), zipfile, StandardCharsets.UTF_8);
|
||||
|
||||
|
@ -688,7 +700,7 @@ public class WordVectorSerializer {
|
|||
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
|
||||
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
|
||||
VocabWord word = vectors.getVocab().elementAtIndex(i);
|
||||
builder = new StringBuilder(encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency())
|
||||
builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency())
|
||||
.append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel()));
|
||||
|
||||
writer.println(builder.toString().trim());
|
||||
|
@ -744,7 +756,7 @@ public class WordVectorSerializer {
|
|||
try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
VocabWord word = vectors.getVocab().tokenFor(decodeB64(line.trim()));
|
||||
VocabWord word = vectors.getVocab().tokenFor(ReadHelper.decodeB64(line.trim()));
|
||||
if (word != null) {
|
||||
word.markAsLabel(true);
|
||||
}
|
||||
|
@ -836,7 +848,7 @@ public class WordVectorSerializer {
|
|||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
String[] split = line.split(" ");
|
||||
VocabWord word = w2v.getVocab().tokenFor(decodeB64(split[0]));
|
||||
VocabWord word = w2v.getVocab().tokenFor(ReadHelper.decodeB64(split[0]));
|
||||
word.setElementFrequency((long) Double.parseDouble(split[1]));
|
||||
word.setSequencesCount((long) Double.parseDouble(split[2]));
|
||||
}
|
||||
|
@ -946,7 +958,7 @@ public class WordVectorSerializer {
|
|||
reader = new BufferedReader(new FileReader(h_points));
|
||||
while ((line = reader.readLine()) != null) {
|
||||
String[] split = line.split(" ");
|
||||
VocabWord word = vocab.wordFor(decodeB64(split[0]));
|
||||
VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0]));
|
||||
List<Integer> points = new ArrayList<>();
|
||||
for (int i = 1; i < split.length; i++) {
|
||||
points.add(Integer.parseInt(split[i]));
|
||||
|
@ -960,7 +972,7 @@ public class WordVectorSerializer {
|
|||
reader = new BufferedReader(new FileReader(h_codes));
|
||||
while ((line = reader.readLine()) != null) {
|
||||
String[] split = line.split(" ");
|
||||
VocabWord word = vocab.wordFor(decodeB64(split[0]));
|
||||
VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0]));
|
||||
List<Byte> codes = new ArrayList<>();
|
||||
for (int i = 1; i < split.length; i++) {
|
||||
codes.add(Byte.parseByte(split[i]));
|
||||
|
@ -1704,7 +1716,7 @@ public class WordVectorSerializer {
|
|||
if (line.isEmpty())
|
||||
line = iter.nextLine();
|
||||
String[] split = line.split(" ");
|
||||
String word = decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
|
||||
String word = ReadHelper.decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
|
||||
VocabWord word1 = new VocabWord(1.0, word);
|
||||
|
||||
word1.setIndex(cache.numWords());
|
||||
|
@ -1994,7 +2006,13 @@ public class WordVectorSerializer {
|
|||
private static final String SYN1_ENTRY = "syn1.bin";
|
||||
private static final String SYN1_NEG_ENTRY = "syn1neg.bin";
|
||||
|
||||
|
||||
/**
|
||||
* This method saves specified SequenceVectors model to target OutputStream
|
||||
*
|
||||
* @param vectors SequenceVectors model
|
||||
* @param stream Target output stream
|
||||
* @param <T>
|
||||
*/
|
||||
public static <T extends SequenceElement> void writeSequenceVectors(@NonNull SequenceVectors<T> vectors,
|
||||
@NonNull OutputStream stream)
|
||||
throws IOException {
|
||||
|
@ -2040,7 +2058,13 @@ public class WordVectorSerializer {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* This method loads SequenceVectors from specified file path
|
||||
*
|
||||
* @param path String
|
||||
* @param readExtendedTables boolean
|
||||
* @param <T>
|
||||
*/
|
||||
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull String path,
|
||||
boolean readExtendedTables)
|
||||
throws IOException {
|
||||
|
@ -2050,6 +2074,14 @@ public class WordVectorSerializer {
|
|||
return vectors;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method loads SequenceVectors from specified file path
|
||||
*
|
||||
* @param file File
|
||||
* @param readExtendedTables boolean
|
||||
* @param <T>
|
||||
*/
|
||||
|
||||
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull File file,
|
||||
boolean readExtendedTables)
|
||||
throws IOException {
|
||||
|
@ -2058,6 +2090,13 @@ public class WordVectorSerializer {
|
|||
return vectors;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method loads SequenceVectors from specified input stream
|
||||
*
|
||||
* @param stream InputStream
|
||||
* @param readExtendedTables boolean
|
||||
* @param <T>
|
||||
*/
|
||||
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull InputStream stream,
|
||||
boolean readExtendedTables)
|
||||
throws IOException {
|
||||
|
@ -2381,6 +2420,12 @@ public class WordVectorSerializer {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This method loads Word2Vec model from binary file
|
||||
*
|
||||
* @param file File
|
||||
* @return Word2Vec
|
||||
*/
|
||||
public static Word2Vec readAsBinary(@NonNull File file) {
|
||||
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
|
||||
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
|
||||
|
@ -2403,6 +2448,12 @@ public class WordVectorSerializer {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This method loads Word2Vec model from csv file
|
||||
*
|
||||
* @param file File
|
||||
* @return Word2Vec
|
||||
*/
|
||||
public static Word2Vec readAsCsv(@NonNull File file) {
|
||||
|
||||
Word2Vec vec;
|
||||
|
@ -2491,7 +2542,7 @@ public class WordVectorSerializer {
|
|||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
String[] split = line.split(" ");
|
||||
VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0]));
|
||||
VocabWord word = new VocabWord(Double.valueOf(split[1]), ReadHelper.decodeB64(split[0]));
|
||||
word.setIndex(cnt.getAndIncrement());
|
||||
word.incrementSequencesCount(Long.valueOf(split[2]));
|
||||
|
||||
|
@ -2669,7 +2720,7 @@ public class WordVectorSerializer {
|
|||
*
|
||||
* In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
|
||||
*
|
||||
* @param file File should point to previously saved w2v model
|
||||
* @param inputStream InputStream should point to previously saved w2v model
|
||||
* @return
|
||||
*/
|
||||
public static WordVectors loadStaticModel(InputStream inputStream) throws IOException {
|
||||
|
@ -2685,6 +2736,17 @@ public class WordVectorSerializer {
|
|||
}
|
||||
|
||||
// TODO: this method needs better name :)
|
||||
/**
|
||||
* This method restores previously saved w2v model. File can be in one of the following formats:
|
||||
* 1) Binary model, either compressed or not. Like well-known Google Model
|
||||
* 2) Popular CSV word2vec text format
|
||||
* 3) DL4j compressed format
|
||||
*
|
||||
* In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
|
||||
*
|
||||
* @param file File
|
||||
* @return
|
||||
*/
|
||||
public static WordVectors loadStaticModel(@NonNull File file) {
|
||||
if (!file.exists() || file.isDirectory())
|
||||
throw new RuntimeException(
|
||||
|
@ -2843,8 +2905,8 @@ public class WordVectorSerializer {
|
|||
throw new RuntimeException(e);
|
||||
}
|
||||
try {
|
||||
numWords = Integer.parseInt(readString(stream));
|
||||
vectorLength = Integer.parseInt(readString(stream));
|
||||
numWords = Integer.parseInt(ReadHelper.readString(stream));
|
||||
vectorLength = Integer.parseInt(ReadHelper.readString(stream));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
@ -2858,13 +2920,13 @@ public class WordVectorSerializer {
|
|||
@Override
|
||||
public Pair<VocabWord, float[]> next() {
|
||||
try {
|
||||
String word = readString(stream);
|
||||
String word = ReadHelper.readString(stream);
|
||||
VocabWord element = new VocabWord(1.0, word);
|
||||
element.setIndex(idxCounter.getAndIncrement());
|
||||
|
||||
float[] vector = new float[vectorLength];
|
||||
for (int i = 0; i < vectorLength; i++) {
|
||||
vector[i] = readFloat(stream);
|
||||
vector[i] = ReadHelper.readFloat(stream);
|
||||
}
|
||||
|
||||
return Pair.makePair(element, vector);
|
||||
|
@ -2913,7 +2975,7 @@ public class WordVectorSerializer {
|
|||
|
||||
String[] split = nextLine.split(" ");
|
||||
|
||||
VocabWord word = new VocabWord(1.0, decodeB64(split[0]));
|
||||
VocabWord word = new VocabWord(1.0, ReadHelper.decodeB64(split[0]));
|
||||
word.setIndex(idxCounter.getAndIncrement());
|
||||
|
||||
float[] vector = new float[split.length - 1];
|
||||
|
@ -2937,26 +2999,12 @@ public class WordVectorSerializer {
|
|||
}
|
||||
}
|
||||
|
||||
public static String encodeB64(String word) {
|
||||
try {
|
||||
return "B64:" + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", "");
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static String decodeB64(String word) {
|
||||
if (word.startsWith("B64:")) {
|
||||
String arp = word.replaceFirst("B64:", "");
|
||||
try {
|
||||
return new String(Base64.decodeBase64(arp), "UTF-8");
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
} else
|
||||
return word;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method saves Word2Vec model to output stream
|
||||
*
|
||||
* @param word2Vec Word2Vec
|
||||
* @param stream OutputStream
|
||||
*/
|
||||
public static void writeWord2Vec(@NonNull Word2Vec word2Vec, @NonNull OutputStream stream)
|
||||
throws IOException {
|
||||
|
||||
|
@ -2968,6 +3016,13 @@ public class WordVectorSerializer {
|
|||
writeSequenceVectors(vectors, stream);
|
||||
}
|
||||
|
||||
/**
|
||||
* This method restores Word2Vec model from file
|
||||
*
|
||||
* @param path String
|
||||
* @param readExtendedTables booleab
|
||||
* @return Word2Vec
|
||||
*/
|
||||
public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables)
|
||||
throws IOException {
|
||||
|
||||
|
@ -2976,6 +3031,12 @@ public class WordVectorSerializer {
|
|||
return word2Vec;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method saves table of weights to file
|
||||
*
|
||||
* @param weightLookupTable WeightLookupTable
|
||||
* @param file File
|
||||
*/
|
||||
public static <T extends SequenceElement> void writeLookupTable(WeightLookupTable<T> weightLookupTable,
|
||||
@NonNull File file) throws IOException {
|
||||
try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file),
|
||||
|
@ -3038,7 +3099,7 @@ public class WordVectorSerializer {
|
|||
headerRead = true;
|
||||
weightLookupTable = new InMemoryLookupTable.Builder().cache(vocabCache).vectorLength(layerSize).build();
|
||||
} else {
|
||||
String label = decodeB64(tokens[0]);
|
||||
String label = ReadHelper.decodeB64(tokens[0]);
|
||||
int freq = Integer.parseInt(tokens[1]);
|
||||
int rows = Integer.parseInt(tokens[2]);
|
||||
int cols = Integer.parseInt(tokens[3]);
|
||||
|
@ -3071,6 +3132,13 @@ public class WordVectorSerializer {
|
|||
return weightLookupTable;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method loads Word2Vec model from file
|
||||
*
|
||||
* @param file File
|
||||
* @param readExtendedTables boolean
|
||||
* @return Word2Vec
|
||||
*/
|
||||
public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables)
|
||||
throws IOException {
|
||||
|
||||
|
@ -3078,6 +3146,13 @@ public class WordVectorSerializer {
|
|||
return word2Vec;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method loads Word2Vec model from input stream
|
||||
*
|
||||
* @param stream InputStream
|
||||
* @param readExtendedTable boolean
|
||||
* @return Word2Vec
|
||||
*/
|
||||
public static Word2Vec readWord2Vec(@NonNull InputStream stream,
|
||||
boolean readExtendedTable) throws IOException {
|
||||
SequenceVectors<VocabWord> vectors = readSequenceVectors(stream, readExtendedTable);
|
||||
|
@ -3087,7 +3162,13 @@ public class WordVectorSerializer {
|
|||
word2Vec.setModelUtils(vectors.getModelUtils());
|
||||
return word2Vec;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* This method loads FastText model to file
|
||||
*
|
||||
* @param vectors FastText
|
||||
* @param path File
|
||||
*/
|
||||
public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException {
|
||||
ObjectOutputStream outputStream = null;
|
||||
try {
|
||||
|
@ -3106,6 +3187,11 @@ public class WordVectorSerializer {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This method unloads FastText model from file
|
||||
*
|
||||
* @param path File
|
||||
*/
|
||||
public static FastText readWordVectors(File path) {
|
||||
FastText result = null;
|
||||
try {
|
||||
|
@ -3124,6 +3210,13 @@ public class WordVectorSerializer {
|
|||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method prints memory usage to log
|
||||
*
|
||||
* @param numWords
|
||||
* @param vectorLength
|
||||
* @param numTables
|
||||
*/
|
||||
public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) {
|
||||
double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables;
|
||||
|
||||
|
@ -3144,4 +3237,102 @@ public class WordVectorSerializer {
|
|||
OneTimeLogger.info(log, "Projected memory use for model: [{} {}]", String.format("%.2f", value), sfx);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper static methods to read data from input stream.
|
||||
*/
|
||||
private static class ReadHelper {
|
||||
/**
|
||||
* Read a float from a data input stream Credit to:
|
||||
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
|
||||
*
|
||||
* @param is
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
private static float readFloat(InputStream is) throws IOException {
|
||||
byte[] bytes = new byte[4];
|
||||
is.read(bytes);
|
||||
return getFloat(bytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Read a string from a data input stream Credit to:
|
||||
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
|
||||
*
|
||||
* @param b
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
private static float getFloat(byte[] b) {
|
||||
int accum = 0;
|
||||
accum = accum | (b[0] & 0xff) << 0;
|
||||
accum = accum | (b[1] & 0xff) << 8;
|
||||
accum = accum | (b[2] & 0xff) << 16;
|
||||
accum = accum | (b[3] & 0xff) << 24;
|
||||
return Float.intBitsToFloat(accum);
|
||||
}
|
||||
|
||||
/**
|
||||
* Read a string from a data input stream Credit to:
|
||||
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
|
||||
*
|
||||
* @param dis
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
private static String readString(DataInputStream dis) throws IOException {
|
||||
byte[] bytes = new byte[MAX_SIZE];
|
||||
byte b = dis.readByte();
|
||||
int i = -1;
|
||||
StringBuilder sb = new StringBuilder();
|
||||
while (b != 32 && b != 10) {
|
||||
i++;
|
||||
bytes[i] = b;
|
||||
b = dis.readByte();
|
||||
if (i == 49) {
|
||||
sb.append(new String(bytes, "UTF-8"));
|
||||
i = -1;
|
||||
bytes = new byte[MAX_SIZE];
|
||||
}
|
||||
}
|
||||
sb.append(new String(bytes, 0, i + 1, "UTF-8"));
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
private static final String B64 = "B64:";
|
||||
|
||||
/**
|
||||
* Encode input string
|
||||
*
|
||||
* @param word String
|
||||
* @return String
|
||||
*/
|
||||
private static String encodeB64(String word) {
|
||||
try {
|
||||
return B64 + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", "");
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Encode input string
|
||||
*
|
||||
* @param word String
|
||||
* @return String
|
||||
*/
|
||||
|
||||
private static String decodeB64(String word) {
|
||||
if (word.startsWith(B64)) {
|
||||
String arp = word.replaceFirst(B64, "");
|
||||
try {
|
||||
return new String(Base64.decodeBase64(arp), "UTF-8");
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
} else
|
||||
return word;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue