Documentation from serialization/deserialization in NLP (#221)

* refactoring

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Javadocs

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Javadoc fixed

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Cleanup

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
master
Alexander Stoyakin 2019-09-02 17:17:55 +03:00 committed by raver119
parent 2129d5bcac
commit 90b62c4579
1 changed files with 303 additions and 112 deletions

View File

@ -24,7 +24,6 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator; import org.apache.commons.io.LineIterator;
import org.apache.commons.io.output.CloseShieldOutputStream; import org.apache.commons.io.output.CloseShieldOutputStream;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
@ -52,7 +51,6 @@ import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess; import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.util.DL4JFileUtils; import org.deeplearning4j.util.DL4JFileUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.compression.impl.NoOp; import org.nd4j.compression.impl.NoOp;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
@ -68,8 +66,6 @@ import org.nd4j.util.OneTimeLogger;
import java.io.*; import java.io.*;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@ -78,6 +74,80 @@ import java.util.zip.*;
/** /**
* This is utility class, providing various methods for WordVectors serialization * This is utility class, providing various methods for WordVectors serialization
* *
* List of available serialization methods (please keep this list consistent with source code):
*
* <ul>
* <li>Serializers for Word2Vec:</li>
* {@link #writeWordVectors(WeightLookupTable, File)}
* {@link #writeWordVectors(WeightLookupTable, OutputStream)}
* {@link #writeWord2VecModel(Word2Vec, File)}
* {@link #writeWord2VecModel(Word2Vec, String)}
* {@link #writeWord2VecModel(Word2Vec, OutputStream)}
*
* <li>Deserializers for Word2Vec:</li>
* {@link #readWord2VecModel(File)}
* {@link #readWord2VecModel(String)}
* {@link #readWord2VecModel(File, boolean)}
* {@link #readWord2VecModel(String, boolean)}
* {@link #readAsBinaryNoLineBreaks(File)}
* {@link #readAsBinary(File)}
* {@link #readAsCsv(File)}
* {@link #readBinaryModel(File, boolean, boolean)}
* {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)}
* {@link #readWord2Vec(String, boolean)}
* {@link #readWord2Vec(File, boolean)}
* {@link #readWord2Vec(InputStream, boolean)}
*
* <li>Serializers for ParaVec:</li>
* {@link #writeParagraphVectors(ParagraphVectors, File)}
* {@link #writeParagraphVectors(ParagraphVectors, String)}
* {@link #writeParagraphVectors(ParagraphVectors, OutputStream)}
*
* <li>Deserializers for ParaVec:</li>
* {@link #readParagraphVectors(File)}
* {@link #readParagraphVectors(String)}
* {@link #readParagraphVectors(InputStream)}
*
* <li>Serializers for GloVe:</li>
* {@link #writeWordVectors(Glove, File)}
* {@link #writeWordVectors(Glove, String)}
* {@link #writeWordVectors(Glove, OutputStream)}
*
* <li>Adapters</li>
* {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
* {@link #fromPair(Pair)}
* {@link #loadTxt(File)}
*
* <li>Serializers to tSNE format</li>
* {@link #writeTsneFormat(Glove, INDArray, File)}
* {@link #writeTsneFormat(Word2Vec, INDArray, File)}
*
* <li>FastText serializer:</li>
* {@link #writeWordVectors(FastText, File)}
*
* <li>FastText deserializer:</li>
* {@link #readWordVectors(File)}
*
* <li>SequenceVectors serializers:</li>
* {@link #writeSequenceVectors(SequenceVectors, OutputStream)}
* {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, File)}
* {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, String)}
* {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, OutputStream)}
* {@link #writeLookupTable(WeightLookupTable, File)}
* {@link #writeVocabCache(VocabCache, File)}
* {@link #writeVocabCache(VocabCache, OutputStream)}
*
* <li>SequenceVectors deserializers:</li>
* {@link #readSequenceVectors(File, boolean)}
* {@link #readSequenceVectors(String, boolean)}
* {@link #readSequenceVectors(SequenceElementFactory, File)}
* {@link #readSequenceVectors(InputStream, boolean)}
* {@link #readSequenceVectors(SequenceElementFactory, InputStream)}
* {@link #readLookupTable(File)}
* {@link #readLookupTable(InputStream)}
*
* </ul>
*
* @author Adam Gibson * @author Adam Gibson
* @author raver119 * @author raver119
* @author alexander@skymind.io * @author alexander@skymind.io
@ -97,7 +167,7 @@ public class WordVectorSerializer {
* @throws IOException * @throws IOException
* @throws NumberFormatException * @throws NumberFormatException
*/ */
private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException { /*private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
InMemoryLookupTable lookupTable; InMemoryLookupTable lookupTable;
VocabCache cache; VocabCache cache;
INDArray syn0; INDArray syn0;
@ -142,7 +212,7 @@ public class WordVectorSerializer {
ret.setLookupTable(lookupTable); ret.setLookupTable(lookupTable);
} }
return ret; return ret;
} }*/
/** /**
* Read a binary word2vec file. * Read a binary word2vec file.
@ -173,8 +243,8 @@ public class WordVectorSerializer {
try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName()) try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName())
? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile)); ? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile));
DataInputStream dis = new DataInputStream(bis)) { DataInputStream dis = new DataInputStream(bis)) {
words = Integer.parseInt(readString(dis)); words = Integer.parseInt(ReadHelper.readString(dis));
size = Integer.parseInt(readString(dis)); size = Integer.parseInt(ReadHelper.readString(dis));
syn0 = Nd4j.create(words, size); syn0 = Nd4j.create(words, size);
cache = new AbstractCache<>(); cache = new AbstractCache<>();
@ -188,11 +258,11 @@ public class WordVectorSerializer {
float[] vector = new float[size]; float[] vector = new float[size];
for (int i = 0; i < words; i++) { for (int i = 0; i < words; i++) {
word = readString(dis); word = ReadHelper.readString(dis);
log.trace("Loading " + word + " with word " + i); log.trace("Loading " + word + " with word " + i);
for (int j = 0; j < size; j++) { for (int j = 0; j < size; j++) {
vector[j] = readFloat(dis); vector[j] = ReadHelper.readFloat(dis);
} }
if (cache.containsWord(word)) if (cache.containsWord(word))
@ -236,64 +306,6 @@ public class WordVectorSerializer {
} }
/**
* Read a float from a data input stream Credit to:
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
*
* @param is
* @return
* @throws IOException
*/
public static float readFloat(InputStream is) throws IOException {
byte[] bytes = new byte[4];
is.read(bytes);
return getFloat(bytes);
}
/**
* Read a string from a data input stream Credit to:
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
*
* @param b
* @return
* @throws IOException
*/
public static float getFloat(byte[] b) {
int accum = 0;
accum = accum | (b[0] & 0xff) << 0;
accum = accum | (b[1] & 0xff) << 8;
accum = accum | (b[2] & 0xff) << 16;
accum = accum | (b[3] & 0xff) << 24;
return Float.intBitsToFloat(accum);
}
/**
* Read a string from a data input stream Credit to:
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
*
* @param dis
* @return
* @throws IOException
*/
public static String readString(DataInputStream dis) throws IOException {
byte[] bytes = new byte[MAX_SIZE];
byte b = dis.readByte();
int i = -1;
StringBuilder sb = new StringBuilder();
while (b != 32 && b != 10) {
i++;
bytes[i] = b;
b = dis.readByte();
if (i == 49) {
sb.append(new String(bytes, "UTF-8"));
i = -1;
bytes = new byte[MAX_SIZE];
}
}
sb.append(new String(bytes, 0, i + 1, "UTF-8"));
return sb.toString();
}
/** /**
* This method writes word vectors to the given path. * This method writes word vectors to the given path.
* Please note: this method doesn't load whole vocab/lookupTable into memory, so it's able to process large vocabularies served over network. * Please note: this method doesn't load whole vocab/lookupTable into memory, so it's able to process large vocabularies served over network.
@ -355,7 +367,7 @@ public class WordVectorSerializer {
val builder = new StringBuilder(); val builder = new StringBuilder();
val l = element.getLabel(); val l = element.getLabel();
builder.append(encodeB64(l)).append(" "); builder.append(ReadHelper.encodeB64(l)).append(" ");
val vec = lookupTable.vector(element.getLabel()); val vec = lookupTable.vector(element.getLabel());
for (int i = 0; i < vec.length(); i++) { for (int i = 0; i < vec.length(); i++) {
builder.append(vec.getDouble(i)); builder.append(vec.getDouble(i));
@ -518,7 +530,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) { for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i); VocabWord word = vectors.getVocab().elementAtIndex(i);
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int code : word.getCodes()) { for (int code : word.getCodes()) {
builder.append(code).append(" "); builder.append(code).append(" ");
} }
@ -536,7 +548,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) { for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i); VocabWord word = vectors.getVocab().elementAtIndex(i);
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int point : word.getPoints()) { for (int point : word.getPoints()) {
builder.append(point).append(" "); builder.append(point).append(" ");
} }
@ -554,7 +566,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) { for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i); VocabWord word = vectors.getVocab().elementAtIndex(i);
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ") StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ")
.append(word.getElementFrequency()).append(" ") .append(word.getElementFrequency()).append(" ")
.append(vectors.getVocab().docAppearedIn(word.getLabel())); .append(vectors.getVocab().docAppearedIn(word.getLabel()));
@ -638,7 +650,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) { for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i); VocabWord word = vectors.getVocab().elementAtIndex(i);
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int code : word.getCodes()) { for (int code : word.getCodes()) {
builder.append(code).append(" "); builder.append(code).append(" ");
} }
@ -656,7 +668,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) { for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i); VocabWord word = vectors.getVocab().elementAtIndex(i);
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int point : word.getPoints()) { for (int point : word.getPoints()) {
builder.append(point).append(" "); builder.append(point).append(" ");
} }
@ -677,7 +689,7 @@ public class WordVectorSerializer {
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
for (VocabWord word : vectors.getVocab().tokens()) { for (VocabWord word : vectors.getVocab().tokens()) {
if (word.isLabel()) if (word.isLabel())
builder.append(encodeB64(word.getLabel())).append("\n"); builder.append(ReadHelper.encodeB64(word.getLabel())).append("\n");
} }
IOUtils.write(builder.toString().trim(), zipfile, StandardCharsets.UTF_8); IOUtils.write(builder.toString().trim(), zipfile, StandardCharsets.UTF_8);
@ -688,7 +700,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) { for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i); VocabWord word = vectors.getVocab().elementAtIndex(i);
builder = new StringBuilder(encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency()) builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency())
.append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel())); .append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel()));
writer.println(builder.toString().trim()); writer.println(builder.toString().trim());
@ -744,7 +756,7 @@ public class WordVectorSerializer {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) { try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
VocabWord word = vectors.getVocab().tokenFor(decodeB64(line.trim())); VocabWord word = vectors.getVocab().tokenFor(ReadHelper.decodeB64(line.trim()));
if (word != null) { if (word != null) {
word.markAsLabel(true); word.markAsLabel(true);
} }
@ -836,7 +848,7 @@ public class WordVectorSerializer {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
String[] split = line.split(" "); String[] split = line.split(" ");
VocabWord word = w2v.getVocab().tokenFor(decodeB64(split[0])); VocabWord word = w2v.getVocab().tokenFor(ReadHelper.decodeB64(split[0]));
word.setElementFrequency((long) Double.parseDouble(split[1])); word.setElementFrequency((long) Double.parseDouble(split[1]));
word.setSequencesCount((long) Double.parseDouble(split[2])); word.setSequencesCount((long) Double.parseDouble(split[2]));
} }
@ -946,7 +958,7 @@ public class WordVectorSerializer {
reader = new BufferedReader(new FileReader(h_points)); reader = new BufferedReader(new FileReader(h_points));
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
String[] split = line.split(" "); String[] split = line.split(" ");
VocabWord word = vocab.wordFor(decodeB64(split[0])); VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0]));
List<Integer> points = new ArrayList<>(); List<Integer> points = new ArrayList<>();
for (int i = 1; i < split.length; i++) { for (int i = 1; i < split.length; i++) {
points.add(Integer.parseInt(split[i])); points.add(Integer.parseInt(split[i]));
@ -960,7 +972,7 @@ public class WordVectorSerializer {
reader = new BufferedReader(new FileReader(h_codes)); reader = new BufferedReader(new FileReader(h_codes));
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
String[] split = line.split(" "); String[] split = line.split(" ");
VocabWord word = vocab.wordFor(decodeB64(split[0])); VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0]));
List<Byte> codes = new ArrayList<>(); List<Byte> codes = new ArrayList<>();
for (int i = 1; i < split.length; i++) { for (int i = 1; i < split.length; i++) {
codes.add(Byte.parseByte(split[i])); codes.add(Byte.parseByte(split[i]));
@ -1704,7 +1716,7 @@ public class WordVectorSerializer {
if (line.isEmpty()) if (line.isEmpty())
line = iter.nextLine(); line = iter.nextLine();
String[] split = line.split(" "); String[] split = line.split(" ");
String word = decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " "); String word = ReadHelper.decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
VocabWord word1 = new VocabWord(1.0, word); VocabWord word1 = new VocabWord(1.0, word);
word1.setIndex(cache.numWords()); word1.setIndex(cache.numWords());
@ -1994,7 +2006,13 @@ public class WordVectorSerializer {
private static final String SYN1_ENTRY = "syn1.bin"; private static final String SYN1_ENTRY = "syn1.bin";
private static final String SYN1_NEG_ENTRY = "syn1neg.bin"; private static final String SYN1_NEG_ENTRY = "syn1neg.bin";
/**
* This method saves specified SequenceVectors model to target OutputStream
*
* @param vectors SequenceVectors model
* @param stream Target output stream
* @param <T>
*/
public static <T extends SequenceElement> void writeSequenceVectors(@NonNull SequenceVectors<T> vectors, public static <T extends SequenceElement> void writeSequenceVectors(@NonNull SequenceVectors<T> vectors,
@NonNull OutputStream stream) @NonNull OutputStream stream)
throws IOException { throws IOException {
@ -2040,7 +2058,13 @@ public class WordVectorSerializer {
} }
} }
/**
* This method loads SequenceVectors from specified file path
*
* @param path String
* @param readExtendedTables boolean
* @param <T>
*/
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull String path, public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull String path,
boolean readExtendedTables) boolean readExtendedTables)
throws IOException { throws IOException {
@ -2050,6 +2074,14 @@ public class WordVectorSerializer {
return vectors; return vectors;
} }
/**
* This method loads SequenceVectors from specified file path
*
* @param file File
* @param readExtendedTables boolean
* @param <T>
*/
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull File file, public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull File file,
boolean readExtendedTables) boolean readExtendedTables)
throws IOException { throws IOException {
@ -2058,6 +2090,13 @@ public class WordVectorSerializer {
return vectors; return vectors;
} }
/**
* This method loads SequenceVectors from specified input stream
*
* @param stream InputStream
* @param readExtendedTables boolean
* @param <T>
*/
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull InputStream stream, public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(@NonNull InputStream stream,
boolean readExtendedTables) boolean readExtendedTables)
throws IOException { throws IOException {
@ -2381,6 +2420,12 @@ public class WordVectorSerializer {
} }
} }
/**
* This method loads Word2Vec model from binary file
*
* @param file File
* @return Word2Vec
*/
public static Word2Vec readAsBinary(@NonNull File file) { public static Word2Vec readAsBinary(@NonNull File file) {
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive(); boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency(); int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
@ -2403,6 +2448,12 @@ public class WordVectorSerializer {
} }
} }
/**
* This method loads Word2Vec model from csv file
*
* @param file File
* @return Word2Vec
*/
public static Word2Vec readAsCsv(@NonNull File file) { public static Word2Vec readAsCsv(@NonNull File file) {
Word2Vec vec; Word2Vec vec;
@ -2491,7 +2542,7 @@ public class WordVectorSerializer {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
String[] split = line.split(" "); String[] split = line.split(" ");
VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0])); VocabWord word = new VocabWord(Double.valueOf(split[1]), ReadHelper.decodeB64(split[0]));
word.setIndex(cnt.getAndIncrement()); word.setIndex(cnt.getAndIncrement());
word.incrementSequencesCount(Long.valueOf(split[2])); word.incrementSequencesCount(Long.valueOf(split[2]));
@ -2669,7 +2720,7 @@ public class WordVectorSerializer {
* *
* In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment. * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
* *
* @param file File should point to previously saved w2v model * @param inputStream InputStream should point to previously saved w2v model
* @return * @return
*/ */
public static WordVectors loadStaticModel(InputStream inputStream) throws IOException { public static WordVectors loadStaticModel(InputStream inputStream) throws IOException {
@ -2685,6 +2736,17 @@ public class WordVectorSerializer {
} }
// TODO: this method needs better name :) // TODO: this method needs better name :)
/**
* This method restores previously saved w2v model. File can be in one of the following formats:
* 1) Binary model, either compressed or not. Like well-known Google Model
* 2) Popular CSV word2vec text format
* 3) DL4j compressed format
*
* In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
*
* @param file File
* @return
*/
public static WordVectors loadStaticModel(@NonNull File file) { public static WordVectors loadStaticModel(@NonNull File file) {
if (!file.exists() || file.isDirectory()) if (!file.exists() || file.isDirectory())
throw new RuntimeException( throw new RuntimeException(
@ -2843,8 +2905,8 @@ public class WordVectorSerializer {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
try { try {
numWords = Integer.parseInt(readString(stream)); numWords = Integer.parseInt(ReadHelper.readString(stream));
vectorLength = Integer.parseInt(readString(stream)); vectorLength = Integer.parseInt(ReadHelper.readString(stream));
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -2858,13 +2920,13 @@ public class WordVectorSerializer {
@Override @Override
public Pair<VocabWord, float[]> next() { public Pair<VocabWord, float[]> next() {
try { try {
String word = readString(stream); String word = ReadHelper.readString(stream);
VocabWord element = new VocabWord(1.0, word); VocabWord element = new VocabWord(1.0, word);
element.setIndex(idxCounter.getAndIncrement()); element.setIndex(idxCounter.getAndIncrement());
float[] vector = new float[vectorLength]; float[] vector = new float[vectorLength];
for (int i = 0; i < vectorLength; i++) { for (int i = 0; i < vectorLength; i++) {
vector[i] = readFloat(stream); vector[i] = ReadHelper.readFloat(stream);
} }
return Pair.makePair(element, vector); return Pair.makePair(element, vector);
@ -2913,7 +2975,7 @@ public class WordVectorSerializer {
String[] split = nextLine.split(" "); String[] split = nextLine.split(" ");
VocabWord word = new VocabWord(1.0, decodeB64(split[0])); VocabWord word = new VocabWord(1.0, ReadHelper.decodeB64(split[0]));
word.setIndex(idxCounter.getAndIncrement()); word.setIndex(idxCounter.getAndIncrement());
float[] vector = new float[split.length - 1]; float[] vector = new float[split.length - 1];
@ -2937,26 +2999,12 @@ public class WordVectorSerializer {
} }
} }
public static String encodeB64(String word) { /**
try { * This method saves Word2Vec model to output stream
return "B64:" + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", ""); *
} catch (Exception e) { * @param word2Vec Word2Vec
throw new RuntimeException(e); * @param stream OutputStream
} */
}
public static String decodeB64(String word) {
if (word.startsWith("B64:")) {
String arp = word.replaceFirst("B64:", "");
try {
return new String(Base64.decodeBase64(arp), "UTF-8");
} catch (Exception e) {
throw new RuntimeException(e);
}
} else
return word;
}
public static void writeWord2Vec(@NonNull Word2Vec word2Vec, @NonNull OutputStream stream) public static void writeWord2Vec(@NonNull Word2Vec word2Vec, @NonNull OutputStream stream)
throws IOException { throws IOException {
@ -2968,6 +3016,13 @@ public class WordVectorSerializer {
writeSequenceVectors(vectors, stream); writeSequenceVectors(vectors, stream);
} }
/**
* This method restores Word2Vec model from file
*
* @param path String
* @param readExtendedTables booleab
* @return Word2Vec
*/
public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables) public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables)
throws IOException { throws IOException {
@ -2976,6 +3031,12 @@ public class WordVectorSerializer {
return word2Vec; return word2Vec;
} }
/**
* This method saves table of weights to file
*
* @param weightLookupTable WeightLookupTable
* @param file File
*/
public static <T extends SequenceElement> void writeLookupTable(WeightLookupTable<T> weightLookupTable, public static <T extends SequenceElement> void writeLookupTable(WeightLookupTable<T> weightLookupTable,
@NonNull File file) throws IOException { @NonNull File file) throws IOException {
try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file), try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file),
@ -3038,7 +3099,7 @@ public class WordVectorSerializer {
headerRead = true; headerRead = true;
weightLookupTable = new InMemoryLookupTable.Builder().cache(vocabCache).vectorLength(layerSize).build(); weightLookupTable = new InMemoryLookupTable.Builder().cache(vocabCache).vectorLength(layerSize).build();
} else { } else {
String label = decodeB64(tokens[0]); String label = ReadHelper.decodeB64(tokens[0]);
int freq = Integer.parseInt(tokens[1]); int freq = Integer.parseInt(tokens[1]);
int rows = Integer.parseInt(tokens[2]); int rows = Integer.parseInt(tokens[2]);
int cols = Integer.parseInt(tokens[3]); int cols = Integer.parseInt(tokens[3]);
@ -3071,6 +3132,13 @@ public class WordVectorSerializer {
return weightLookupTable; return weightLookupTable;
} }
/**
* This method loads Word2Vec model from file
*
* @param file File
* @param readExtendedTables boolean
* @return Word2Vec
*/
public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables) public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables)
throws IOException { throws IOException {
@ -3078,6 +3146,13 @@ public class WordVectorSerializer {
return word2Vec; return word2Vec;
} }
/**
* This method loads Word2Vec model from input stream
*
* @param stream InputStream
* @param readExtendedTable boolean
* @return Word2Vec
*/
public static Word2Vec readWord2Vec(@NonNull InputStream stream, public static Word2Vec readWord2Vec(@NonNull InputStream stream,
boolean readExtendedTable) throws IOException { boolean readExtendedTable) throws IOException {
SequenceVectors<VocabWord> vectors = readSequenceVectors(stream, readExtendedTable); SequenceVectors<VocabWord> vectors = readSequenceVectors(stream, readExtendedTable);
@ -3087,7 +3162,13 @@ public class WordVectorSerializer {
word2Vec.setModelUtils(vectors.getModelUtils()); word2Vec.setModelUtils(vectors.getModelUtils());
return word2Vec; return word2Vec;
} }
/**
* This method loads FastText model to file
*
* @param vectors FastText
* @param path File
*/
public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException { public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException {
ObjectOutputStream outputStream = null; ObjectOutputStream outputStream = null;
try { try {
@ -3106,6 +3187,11 @@ public class WordVectorSerializer {
} }
} }
/**
* This method unloads FastText model from file
*
* @param path File
*/
public static FastText readWordVectors(File path) { public static FastText readWordVectors(File path) {
FastText result = null; FastText result = null;
try { try {
@ -3124,6 +3210,13 @@ public class WordVectorSerializer {
return result; return result;
} }
/**
* This method prints memory usage to log
*
* @param numWords
* @param vectorLength
* @param numTables
*/
public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) { public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) {
double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables; double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables;
@ -3144,4 +3237,102 @@ public class WordVectorSerializer {
OneTimeLogger.info(log, "Projected memory use for model: [{} {}]", String.format("%.2f", value), sfx); OneTimeLogger.info(log, "Projected memory use for model: [{} {}]", String.format("%.2f", value), sfx);
} }
/**
* Helper static methods to read data from input stream.
*/
private static class ReadHelper {
/**
* Read a float from a data input stream Credit to:
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
*
* @param is
* @return
* @throws IOException
*/
private static float readFloat(InputStream is) throws IOException {
byte[] bytes = new byte[4];
is.read(bytes);
return getFloat(bytes);
}
/**
* Read a string from a data input stream Credit to:
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
*
* @param b
* @return
* @throws IOException
*/
private static float getFloat(byte[] b) {
int accum = 0;
accum = accum | (b[0] & 0xff) << 0;
accum = accum | (b[1] & 0xff) << 8;
accum = accum | (b[2] & 0xff) << 16;
accum = accum | (b[3] & 0xff) << 24;
return Float.intBitsToFloat(accum);
}
/**
* Read a string from a data input stream Credit to:
* https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
*
* @param dis
* @return
* @throws IOException
*/
private static String readString(DataInputStream dis) throws IOException {
byte[] bytes = new byte[MAX_SIZE];
byte b = dis.readByte();
int i = -1;
StringBuilder sb = new StringBuilder();
while (b != 32 && b != 10) {
i++;
bytes[i] = b;
b = dis.readByte();
if (i == 49) {
sb.append(new String(bytes, "UTF-8"));
i = -1;
bytes = new byte[MAX_SIZE];
}
}
sb.append(new String(bytes, 0, i + 1, "UTF-8"));
return sb.toString();
}
private static final String B64 = "B64:";
/**
* Encode input string
*
* @param word String
* @return String
*/
private static String encodeB64(String word) {
try {
return B64 + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", "");
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/**
* Encode input string
*
* @param word String
* @return String
*/
private static String decodeB64(String word) {
if (word.startsWith(B64)) {
String arp = word.replaceFirst(B64, "");
try {
return new String(Base64.decodeBase64(arp), "UTF-8");
} catch (Exception e) {
throw new RuntimeException(e);
}
} else
return word;
}
}
} }