diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
index cce6a740a..210ab7686 100755
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
@@ -24,7 +24,6 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.apache.commons.io.output.CloseShieldOutputStream;
-import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
@@ -52,7 +51,6 @@ import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.util.DL4JFileUtils;
-import org.nd4j.base.Preconditions;
import org.nd4j.compression.impl.NoOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
@@ -68,8 +66,6 @@ import org.nd4j.util.OneTimeLogger;
import java.io.*;
import java.nio.charset.StandardCharsets;
-import java.nio.file.Files;
-import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
@@ -78,6 +74,80 @@ import java.util.zip.*;
/**
* This is utility class, providing various methods for WordVectors serialization
*
+ * List of available serialization methods (please keep this list consistent with source code):
+ *
+ *
+ * - Serializers for Word2Vec:
+ * {@link #writeWordVectors(WeightLookupTable, File)}
+ * {@link #writeWordVectors(WeightLookupTable, OutputStream)}
+ * {@link #writeWord2VecModel(Word2Vec, File)}
+ * {@link #writeWord2VecModel(Word2Vec, String)}
+ * {@link #writeWord2VecModel(Word2Vec, OutputStream)}
+ *
+ * - Deserializers for Word2Vec:
+ * {@link #readWord2VecModel(File)}
+ * {@link #readWord2VecModel(String)}
+ * {@link #readWord2VecModel(File, boolean)}
+ * {@link #readWord2VecModel(String, boolean)}
+ * {@link #readAsBinaryNoLineBreaks(File)}
+ * {@link #readAsBinary(File)}
+ * {@link #readAsCsv(File)}
+ * {@link #readBinaryModel(File, boolean, boolean)}
+ * {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)}
+ * {@link #readWord2Vec(String, boolean)}
+ * {@link #readWord2Vec(File, boolean)}
+ * {@link #readWord2Vec(InputStream, boolean)}
+ *
+ * - Serializers for ParaVec:
+ * {@link #writeParagraphVectors(ParagraphVectors, File)}
+ * {@link #writeParagraphVectors(ParagraphVectors, String)}
+ * {@link #writeParagraphVectors(ParagraphVectors, OutputStream)}
+ *
+ * - Deserializers for ParaVec:
+ * {@link #readParagraphVectors(File)}
+ * {@link #readParagraphVectors(String)}
+ * {@link #readParagraphVectors(InputStream)}
+ *
+ * - Serializers for GloVe:
+ * {@link #writeWordVectors(Glove, File)}
+ * {@link #writeWordVectors(Glove, String)}
+ * {@link #writeWordVectors(Glove, OutputStream)}
+ *
+ * - Adapters
+ * {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
+ * {@link #fromPair(Pair)}
+ * {@link #loadTxt(File)}
+ *
+ * - Serializers to tSNE format
+ * {@link #writeTsneFormat(Glove, INDArray, File)}
+ * {@link #writeTsneFormat(Word2Vec, INDArray, File)}
+ *
+ * - FastText serializer:
+ * {@link #writeWordVectors(FastText, File)}
+ *
+ * - FastText deserializer:
+ * {@link #readWordVectors(File)}
+ *
+ * - SequenceVectors serializers:
+ * {@link #writeSequenceVectors(SequenceVectors, OutputStream)}
+ * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, File)}
+ * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, String)}
+ * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, OutputStream)}
+ * {@link #writeLookupTable(WeightLookupTable, File)}
+ * {@link #writeVocabCache(VocabCache, File)}
+ * {@link #writeVocabCache(VocabCache, OutputStream)}
+ *
+ * - SequenceVectors deserializers:
+ * {@link #readSequenceVectors(File, boolean)}
+ * {@link #readSequenceVectors(String, boolean)}
+ * {@link #readSequenceVectors(SequenceElementFactory, File)}
+ * {@link #readSequenceVectors(InputStream, boolean)}
+ * {@link #readSequenceVectors(SequenceElementFactory, InputStream)}
+ * {@link #readLookupTable(File)}
+ * {@link #readLookupTable(InputStream)}
+ *
+ *
+ *
* @author Adam Gibson
* @author raver119
* @author alexander@skymind.io
@@ -97,7 +167,7 @@ public class WordVectorSerializer {
* @throws IOException
* @throws NumberFormatException
*/
- private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
+ /*private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
InMemoryLookupTable lookupTable;
VocabCache cache;
INDArray syn0;
@@ -142,7 +212,7 @@ public class WordVectorSerializer {
ret.setLookupTable(lookupTable);
}
return ret;
- }
+ }*/
/**
* Read a binary word2vec file.
@@ -173,8 +243,8 @@ public class WordVectorSerializer {
try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName())
? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile));
DataInputStream dis = new DataInputStream(bis)) {
- words = Integer.parseInt(readString(dis));
- size = Integer.parseInt(readString(dis));
+ words = Integer.parseInt(ReadHelper.readString(dis));
+ size = Integer.parseInt(ReadHelper.readString(dis));
syn0 = Nd4j.create(words, size);
cache = new AbstractCache<>();
@@ -188,11 +258,11 @@ public class WordVectorSerializer {
float[] vector = new float[size];
for (int i = 0; i < words; i++) {
- word = readString(dis);
+ word = ReadHelper.readString(dis);
log.trace("Loading " + word + " with word " + i);
for (int j = 0; j < size; j++) {
- vector[j] = readFloat(dis);
+ vector[j] = ReadHelper.readFloat(dis);
}
if (cache.containsWord(word))
@@ -236,64 +306,6 @@ public class WordVectorSerializer {
}
- /**
- * Read a float from a data input stream Credit to:
- * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
- *
- * @param is
- * @return
- * @throws IOException
- */
- public static float readFloat(InputStream is) throws IOException {
- byte[] bytes = new byte[4];
- is.read(bytes);
- return getFloat(bytes);
- }
-
- /**
- * Read a string from a data input stream Credit to:
- * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
- *
- * @param b
- * @return
- * @throws IOException
- */
- public static float getFloat(byte[] b) {
- int accum = 0;
- accum = accum | (b[0] & 0xff) << 0;
- accum = accum | (b[1] & 0xff) << 8;
- accum = accum | (b[2] & 0xff) << 16;
- accum = accum | (b[3] & 0xff) << 24;
- return Float.intBitsToFloat(accum);
- }
-
- /**
- * Read a string from a data input stream Credit to:
- * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
- *
- * @param dis
- * @return
- * @throws IOException
- */
- public static String readString(DataInputStream dis) throws IOException {
- byte[] bytes = new byte[MAX_SIZE];
- byte b = dis.readByte();
- int i = -1;
- StringBuilder sb = new StringBuilder();
- while (b != 32 && b != 10) {
- i++;
- bytes[i] = b;
- b = dis.readByte();
- if (i == 49) {
- sb.append(new String(bytes, "UTF-8"));
- i = -1;
- bytes = new byte[MAX_SIZE];
- }
- }
- sb.append(new String(bytes, 0, i + 1, "UTF-8"));
- return sb.toString();
- }
-
/**
* This method writes word vectors to the given path.
* Please note: this method doesn't load whole vocab/lookupTable into memory, so it's able to process large vocabularies served over network.
@@ -355,7 +367,7 @@ public class WordVectorSerializer {
val builder = new StringBuilder();
val l = element.getLabel();
- builder.append(encodeB64(l)).append(" ");
+ builder.append(ReadHelper.encodeB64(l)).append(" ");
val vec = lookupTable.vector(element.getLabel());
for (int i = 0; i < vec.length(); i++) {
builder.append(vec.getDouble(i));
@@ -518,7 +530,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i);
- StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
+ StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int code : word.getCodes()) {
builder.append(code).append(" ");
}
@@ -536,7 +548,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i);
- StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
+ StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int point : word.getPoints()) {
builder.append(point).append(" ");
}
@@ -554,7 +566,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i);
- StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ")
+ StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ")
.append(word.getElementFrequency()).append(" ")
.append(vectors.getVocab().docAppearedIn(word.getLabel()));
@@ -638,7 +650,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i);
- StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
+ StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int code : word.getCodes()) {
builder.append(code).append(" ");
}
@@ -656,7 +668,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i);
- StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
+ StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ");
for (int point : word.getPoints()) {
builder.append(point).append(" ");
}
@@ -677,7 +689,7 @@ public class WordVectorSerializer {
StringBuilder builder = new StringBuilder();
for (VocabWord word : vectors.getVocab().tokens()) {
if (word.isLabel())
- builder.append(encodeB64(word.getLabel())).append("\n");
+ builder.append(ReadHelper.encodeB64(word.getLabel())).append("\n");
}
IOUtils.write(builder.toString().trim(), zipfile, StandardCharsets.UTF_8);
@@ -688,7 +700,7 @@ public class WordVectorSerializer {
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
VocabWord word = vectors.getVocab().elementAtIndex(i);
- builder = new StringBuilder(encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency())
+ builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency())
.append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel()));
writer.println(builder.toString().trim());
@@ -744,7 +756,7 @@ public class WordVectorSerializer {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
- VocabWord word = vectors.getVocab().tokenFor(decodeB64(line.trim()));
+ VocabWord word = vectors.getVocab().tokenFor(ReadHelper.decodeB64(line.trim()));
if (word != null) {
word.markAsLabel(true);
}
@@ -836,7 +848,7 @@ public class WordVectorSerializer {
String line;
while ((line = reader.readLine()) != null) {
String[] split = line.split(" ");
- VocabWord word = w2v.getVocab().tokenFor(decodeB64(split[0]));
+ VocabWord word = w2v.getVocab().tokenFor(ReadHelper.decodeB64(split[0]));
word.setElementFrequency((long) Double.parseDouble(split[1]));
word.setSequencesCount((long) Double.parseDouble(split[2]));
}
@@ -946,7 +958,7 @@ public class WordVectorSerializer {
reader = new BufferedReader(new FileReader(h_points));
while ((line = reader.readLine()) != null) {
String[] split = line.split(" ");
- VocabWord word = vocab.wordFor(decodeB64(split[0]));
+ VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0]));
List points = new ArrayList<>();
for (int i = 1; i < split.length; i++) {
points.add(Integer.parseInt(split[i]));
@@ -960,7 +972,7 @@ public class WordVectorSerializer {
reader = new BufferedReader(new FileReader(h_codes));
while ((line = reader.readLine()) != null) {
String[] split = line.split(" ");
- VocabWord word = vocab.wordFor(decodeB64(split[0]));
+ VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0]));
List codes = new ArrayList<>();
for (int i = 1; i < split.length; i++) {
codes.add(Byte.parseByte(split[i]));
@@ -1704,7 +1716,7 @@ public class WordVectorSerializer {
if (line.isEmpty())
line = iter.nextLine();
String[] split = line.split(" ");
- String word = decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
+ String word = ReadHelper.decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
VocabWord word1 = new VocabWord(1.0, word);
word1.setIndex(cache.numWords());
@@ -1994,7 +2006,13 @@ public class WordVectorSerializer {
private static final String SYN1_ENTRY = "syn1.bin";
private static final String SYN1_NEG_ENTRY = "syn1neg.bin";
-
+ /**
+ * This method saves specified SequenceVectors model to target OutputStream
+ *
+ * @param vectors SequenceVectors model
+ * @param stream Target output stream
+ * @param <T> type of the sequence elements
+ */
public static void writeSequenceVectors(@NonNull SequenceVectors vectors,
@NonNull OutputStream stream)
throws IOException {
@@ -2040,7 +2058,13 @@ public class WordVectorSerializer {
}
}
-
+ /**
+ * This method loads SequenceVectors from specified file path
+ *
+ * @param path String
+ * @param readExtendedTables boolean
+ * @param <T> type of the sequence elements
+ */
public static SequenceVectors readSequenceVectors(@NonNull String path,
boolean readExtendedTables)
throws IOException {
@@ -2050,6 +2074,14 @@ public class WordVectorSerializer {
return vectors;
}
+ /**
+ * This method loads SequenceVectors from specified file path
+ *
+ * @param file File
+ * @param readExtendedTables boolean
+ * @param <T> type of the sequence elements
+ */
+
public static SequenceVectors readSequenceVectors(@NonNull File file,
boolean readExtendedTables)
throws IOException {
@@ -2058,6 +2090,13 @@ public class WordVectorSerializer {
return vectors;
}
+ /**
+ * This method loads SequenceVectors from specified input stream
+ *
+ * @param stream InputStream
+ * @param readExtendedTables boolean
+ * @param <T> type of the sequence elements
+ */
public static SequenceVectors readSequenceVectors(@NonNull InputStream stream,
boolean readExtendedTables)
throws IOException {
@@ -2381,6 +2420,12 @@ public class WordVectorSerializer {
}
}
+ /**
+ * This method loads Word2Vec model from binary file
+ *
+ * @param file File
+ * @return Word2Vec
+ */
public static Word2Vec readAsBinary(@NonNull File file) {
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
@@ -2403,6 +2448,12 @@ public class WordVectorSerializer {
}
}
+ /**
+ * This method loads Word2Vec model from csv file
+ *
+ * @param file File
+ * @return Word2Vec
+ */
public static Word2Vec readAsCsv(@NonNull File file) {
Word2Vec vec;
@@ -2491,7 +2542,7 @@ public class WordVectorSerializer {
String line;
while ((line = reader.readLine()) != null) {
String[] split = line.split(" ");
- VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0]));
+ VocabWord word = new VocabWord(Double.valueOf(split[1]), ReadHelper.decodeB64(split[0]));
word.setIndex(cnt.getAndIncrement());
word.incrementSequencesCount(Long.valueOf(split[2]));
@@ -2669,7 +2720,7 @@ public class WordVectorSerializer {
*
* In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
*
- * @param file File should point to previously saved w2v model
+ * @param inputStream InputStream should point to previously saved w2v model
* @return
*/
public static WordVectors loadStaticModel(InputStream inputStream) throws IOException {
@@ -2685,6 +2736,17 @@ public class WordVectorSerializer {
}
// TODO: this method needs better name :)
+ /**
+ * This method restores previously saved w2v model. File can be in one of the following formats:
+ * 1) Binary model, either compressed or not. Like well-known Google Model
+ * 2) Popular CSV word2vec text format
+ * 3) DL4j compressed format
+ *
+ * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment.
+ *
+ * @param file File
+ * @return
+ */
public static WordVectors loadStaticModel(@NonNull File file) {
if (!file.exists() || file.isDirectory())
throw new RuntimeException(
@@ -2843,8 +2905,8 @@ public class WordVectorSerializer {
throw new RuntimeException(e);
}
try {
- numWords = Integer.parseInt(readString(stream));
- vectorLength = Integer.parseInt(readString(stream));
+ numWords = Integer.parseInt(ReadHelper.readString(stream));
+ vectorLength = Integer.parseInt(ReadHelper.readString(stream));
} catch (IOException e) {
throw new RuntimeException(e);
}
@@ -2858,13 +2920,13 @@ public class WordVectorSerializer {
@Override
public Pair next() {
try {
- String word = readString(stream);
+ String word = ReadHelper.readString(stream);
VocabWord element = new VocabWord(1.0, word);
element.setIndex(idxCounter.getAndIncrement());
float[] vector = new float[vectorLength];
for (int i = 0; i < vectorLength; i++) {
- vector[i] = readFloat(stream);
+ vector[i] = ReadHelper.readFloat(stream);
}
return Pair.makePair(element, vector);
@@ -2913,7 +2975,7 @@ public class WordVectorSerializer {
String[] split = nextLine.split(" ");
- VocabWord word = new VocabWord(1.0, decodeB64(split[0]));
+ VocabWord word = new VocabWord(1.0, ReadHelper.decodeB64(split[0]));
word.setIndex(idxCounter.getAndIncrement());
float[] vector = new float[split.length - 1];
@@ -2937,26 +2999,12 @@ public class WordVectorSerializer {
}
}
- public static String encodeB64(String word) {
- try {
- return "B64:" + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", "");
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
- public static String decodeB64(String word) {
- if (word.startsWith("B64:")) {
- String arp = word.replaceFirst("B64:", "");
- try {
- return new String(Base64.decodeBase64(arp), "UTF-8");
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- } else
- return word;
- }
-
+ /**
+ * This method saves Word2Vec model to output stream
+ *
+ * @param word2Vec Word2Vec
+ * @param stream OutputStream
+ */
public static void writeWord2Vec(@NonNull Word2Vec word2Vec, @NonNull OutputStream stream)
throws IOException {
@@ -2968,6 +3016,13 @@ public class WordVectorSerializer {
writeSequenceVectors(vectors, stream);
}
+ /**
+ * This method restores Word2Vec model from file
+ *
+ * @param path String
+ * @param readExtendedTables boolean
+ * @return Word2Vec
+ */
public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables)
throws IOException {
@@ -2976,6 +3031,12 @@ public class WordVectorSerializer {
return word2Vec;
}
+ /**
+ * This method saves table of weights to file
+ *
+ * @param weightLookupTable WeightLookupTable
+ * @param file File
+ */
public static void writeLookupTable(WeightLookupTable weightLookupTable,
@NonNull File file) throws IOException {
try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file),
@@ -3038,7 +3099,7 @@ public class WordVectorSerializer {
headerRead = true;
weightLookupTable = new InMemoryLookupTable.Builder().cache(vocabCache).vectorLength(layerSize).build();
} else {
- String label = decodeB64(tokens[0]);
+ String label = ReadHelper.decodeB64(tokens[0]);
int freq = Integer.parseInt(tokens[1]);
int rows = Integer.parseInt(tokens[2]);
int cols = Integer.parseInt(tokens[3]);
@@ -3071,6 +3132,13 @@ public class WordVectorSerializer {
return weightLookupTable;
}
+ /**
+ * This method loads Word2Vec model from file
+ *
+ * @param file File
+ * @param readExtendedTables boolean
+ * @return Word2Vec
+ */
public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables)
throws IOException {
@@ -3078,6 +3146,13 @@ public class WordVectorSerializer {
return word2Vec;
}
+ /**
+ * This method loads Word2Vec model from input stream
+ *
+ * @param stream InputStream
+ * @param readExtendedTable boolean
+ * @return Word2Vec
+ */
public static Word2Vec readWord2Vec(@NonNull InputStream stream,
boolean readExtendedTable) throws IOException {
SequenceVectors vectors = readSequenceVectors(stream, readExtendedTable);
@@ -3087,7 +3162,13 @@ public class WordVectorSerializer {
word2Vec.setModelUtils(vectors.getModelUtils());
return word2Vec;
}
-
+
+ /**
+ * This method saves FastText model to file
+ *
+ * @param vectors FastText
+ * @param path File
+ */
public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException {
ObjectOutputStream outputStream = null;
try {
@@ -3106,6 +3187,11 @@ public class WordVectorSerializer {
}
}
+ /**
+ * This method loads FastText model from file
+ *
+ * @param path File
+ */
public static FastText readWordVectors(File path) {
FastText result = null;
try {
@@ -3124,6 +3210,13 @@ public class WordVectorSerializer {
return result;
}
+ /**
+ * This method prints memory usage to log
+ *
+ * @param numWords
+ * @param vectorLength
+ * @param numTables
+ */
public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) {
double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables;
@@ -3144,4 +3237,102 @@ public class WordVectorSerializer {
OneTimeLogger.info(log, "Projected memory use for model: [{} {}]", String.format("%.2f", value), sfx);
}
+
+    /**
+     * Helper static methods for reading word2vec binary data and for Base64 label encoding.
+     *
+     * NOTE(review): these replace formerly public static methods of the enclosing class;
+     * confirm that no external caller still depends on the removed public versions.
+     */
+    private static class ReadHelper {
+        /** Prefix that marks a label as Base64-encoded. */
+        private static final String B64 = "B64:";
+
+        /**
+         * Reads a little-endian 32-bit float from the stream. Credit to:
+         * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
+         *
+         * @param is stream positioned at the first of the four bytes
+         * @return the decoded float
+         * @throws IOException if the stream ends before four bytes are read, or on I/O failure
+         */
+        private static float readFloat(InputStream is) throws IOException {
+            byte[] bytes = new byte[4];
+            // InputStream.read(byte[]) may deliver fewer bytes than requested; readFully keeps
+            // reading until all four arrive and throws EOFException on a truncated stream.
+            new DataInputStream(is).readFully(bytes);
+            return getFloat(bytes);
+        }
+
+        /**
+         * Assembles a float from four bytes, least significant byte first. Credit to:
+         * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
+         *
+         * @param b array holding at least four bytes
+         * @return float whose bit pattern is formed by the first four bytes
+         */
+        private static float getFloat(byte[] b) {
+            // Mask each byte to undo sign extension before shifting it into place.
+            int accum = (b[0] & 0xff)
+                    | (b[1] & 0xff) << 8
+                    | (b[2] & 0xff) << 16
+                    | (b[3] & 0xff) << 24;
+            return Float.intBitsToFloat(accum);
+        }
+
+        /**
+         * Reads bytes up to the next space or line feed and decodes them as UTF-8.
+         * The delimiter is consumed but not returned. Credit to:
+         * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java
+         *
+         * @param dis stream to read from
+         * @return the decoded token
+         * @throws IOException if the stream ends before a delimiter is read, or on I/O failure
+         */
+        private static String readString(DataInputStream dis) throws IOException {
+            // Collect the whole token before decoding: decoding fixed-size chunks separately
+            // corrupts multi-byte UTF-8 characters that straddle a chunk boundary.
+            ByteArrayOutputStream buffer = new ByteArrayOutputStream(MAX_SIZE);
+            byte b = dis.readByte();
+            while (b != 32 && b != 10) {
+                buffer.write(b);
+                b = dis.readByte();
+            }
+            return new String(buffer.toByteArray(), StandardCharsets.UTF_8);
+        }
+
+        /**
+         * Encodes a label as Base64 and prepends the "B64:" prefix.
+         *
+         * @param word label to encode
+         * @return prefixed Base64 form of the label's UTF-8 bytes
+         */
+        private static String encodeB64(String word) {
+            try {
+                // Strip any line breaks so the encoded label stays on a single line.
+                return B64 + Base64.encodeBase64String(word.getBytes(StandardCharsets.UTF_8))
+                        .replaceAll("(\r|\n)", "");
+            } catch (Exception e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        /**
+         * Decodes a label produced by {@link #encodeB64(String)}.
+         *
+         * @param word label, possibly carrying the "B64:" prefix
+         * @return decoded label, or the input unchanged when the prefix is absent
+         */
+        private static String decodeB64(String word) {
+            if (!word.startsWith(B64)) {
+                return word;
+            }
+            try {
+                // Drop the prefix by length; replaceFirst would interpret it as a regex.
+                return new String(Base64.decodeBase64(word.substring(B64.length())), StandardCharsets.UTF_8);
+            } catch (Exception e) {
+                throw new RuntimeException(e);
+            }
+        }
+    }
}