From 90b62c457917e20480ef3e3ec0d21c6819239650 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Mon, 2 Sep 2019 17:17:55 +0300 Subject: [PATCH] Documentation from serialization/deserialization in NLP (#221) * refactoring Signed-off-by: Alexander Stoyakin * Javadocs Signed-off-by: Alexander Stoyakin * Javadoc fixed Signed-off-by: Alexander Stoyakin * Cleanup Signed-off-by: Alexander Stoyakin --- .../loader/WordVectorSerializer.java | 415 +++++++++++++----- 1 file changed, 303 insertions(+), 112 deletions(-) diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java index cce6a740a..210ab7686 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java @@ -24,7 +24,6 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.io.LineIterator; import org.apache.commons.io.output.CloseShieldOutputStream; -import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; @@ -52,7 +51,6 @@ import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; import org.deeplearning4j.util.DL4JFileUtils; -import org.nd4j.base.Preconditions; import org.nd4j.compression.impl.NoOp; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -68,8 +66,6 @@ import org.nd4j.util.OneTimeLogger; import java.io.*; import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; @@ -78,6 +74,80 @@ import java.util.zip.*; /** * This is utility class, providing various methods for WordVectors serialization * + * List of available serialization methods (please keep this list consistent with source code): + * + *
    + *
  • Serializers for Word2Vec:
  • + * {@link #writeWordVectors(WeightLookupTable, File)} + * {@link #writeWordVectors(WeightLookupTable, OutputStream)} + * {@link #writeWord2VecModel(Word2Vec, File)} + * {@link #writeWord2VecModel(Word2Vec, String)} + * {@link #writeWord2VecModel(Word2Vec, OutputStream)} + * + *
  • Deserializers for Word2Vec:
  • + * {@link #readWord2VecModel(File)} + * {@link #readWord2VecModel(String)} + * {@link #readWord2VecModel(File, boolean)} + * {@link #readWord2VecModel(String, boolean)} + * {@link #readAsBinaryNoLineBreaks(File)} + * {@link #readAsBinary(File)} + * {@link #readAsCsv(File)} + * {@link #readBinaryModel(File, boolean, boolean)} + * {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)} + * {@link #readWord2Vec(String, boolean)} + * {@link #readWord2Vec(File, boolean)} + * {@link #readWord2Vec(InputStream, boolean)} + * + *
  • Serializers for ParaVec:
  • + * {@link #writeParagraphVectors(ParagraphVectors, File)} + * {@link #writeParagraphVectors(ParagraphVectors, String)} + * {@link #writeParagraphVectors(ParagraphVectors, OutputStream)} + * + *
  • Deserializers for ParaVec:
  • + * {@link #readParagraphVectors(File)} + * {@link #readParagraphVectors(String)} + * {@link #readParagraphVectors(InputStream)} + * + *
  • Serializers for GloVe:
  • + * {@link #writeWordVectors(Glove, File)} + * {@link #writeWordVectors(Glove, String)} + * {@link #writeWordVectors(Glove, OutputStream)} + * + *
  • Adapters
  • + * {@link #fromTableAndVocab(WeightLookupTable, VocabCache)} + * {@link #fromPair(Pair)} + * {@link #loadTxt(File)} + * + *
  • Serializers to tSNE format
  • + * {@link #writeTsneFormat(Glove, INDArray, File)} + * {@link #writeTsneFormat(Word2Vec, INDArray, File)} + * + *
  • FastText serializer:
  • + * {@link #writeWordVectors(FastText, File)} + * + *
  • FastText deserializer:
  • + * {@link #readWordVectors(File)} + * + *
  • SequenceVectors serializers:
  • + * {@link #writeSequenceVectors(SequenceVectors, OutputStream)} + * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, File)} + * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, String)} + * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, OutputStream)} + * {@link #writeLookupTable(WeightLookupTable, File)} + * {@link #writeVocabCache(VocabCache, File)} + * {@link #writeVocabCache(VocabCache, OutputStream)} + * + *
  • SequenceVectors deserializers:
  • + * {@link #readSequenceVectors(File, boolean)} + * {@link #readSequenceVectors(String, boolean)} + * {@link #readSequenceVectors(SequenceElementFactory, File)} + * {@link #readSequenceVectors(InputStream, boolean)} + * {@link #readSequenceVectors(SequenceElementFactory, InputStream)} + * {@link #readLookupTable(File)} + * {@link #readLookupTable(InputStream)} + * + *
+ * * @author Adam Gibson * @author raver119 * @author alexander@skymind.io @@ -97,7 +167,7 @@ public class WordVectorSerializer { * @throws IOException * @throws NumberFormatException */ - private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException { + /*private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException { InMemoryLookupTable lookupTable; VocabCache cache; INDArray syn0; @@ -142,7 +212,7 @@ public class WordVectorSerializer { ret.setLookupTable(lookupTable); } return ret; - } + }*/ /** * Read a binary word2vec file. @@ -173,8 +243,8 @@ public class WordVectorSerializer { try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName()) ? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile)); DataInputStream dis = new DataInputStream(bis)) { - words = Integer.parseInt(readString(dis)); - size = Integer.parseInt(readString(dis)); + words = Integer.parseInt(ReadHelper.readString(dis)); + size = Integer.parseInt(ReadHelper.readString(dis)); syn0 = Nd4j.create(words, size); cache = new AbstractCache<>(); @@ -188,11 +258,11 @@ public class WordVectorSerializer { float[] vector = new float[size]; for (int i = 0; i < words; i++) { - word = readString(dis); + word = ReadHelper.readString(dis); log.trace("Loading " + word + " with word " + i); for (int j = 0; j < size; j++) { - vector[j] = readFloat(dis); + vector[j] = ReadHelper.readFloat(dis); } if (cache.containsWord(word)) @@ -236,64 +306,6 @@ public class WordVectorSerializer { } - /** - * Read a float from a data input stream Credit to: - * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java - * - * @param is - * @return - * @throws IOException - */ - public static float readFloat(InputStream is) throws IOException { - byte[] bytes = new byte[4]; - is.read(bytes); - return getFloat(bytes); - } - - /** - * Read a string from a data input stream Credit to: - * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java - * - * @param b - * @return - * @throws IOException - */ - public static float getFloat(byte[] b) { - int accum = 0; - accum = accum | (b[0] & 0xff) << 0; - accum = accum | (b[1] & 0xff) << 8; - accum = accum | (b[2] & 0xff) << 16; - accum = accum | (b[3] & 0xff) << 24; - return Float.intBitsToFloat(accum); - } - - /** - * Read a string from a data input stream Credit to: - * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java - * - * @param dis - * @return - * @throws IOException - */ - public static String readString(DataInputStream dis) throws IOException { - byte[] bytes = new byte[MAX_SIZE]; - byte b = dis.readByte(); - int i = -1; - StringBuilder sb = new StringBuilder(); - while (b != 32 && b != 10) { - i++; - bytes[i] = b; - b = dis.readByte(); - if (i == 49) { - sb.append(new String(bytes, "UTF-8")); - i = -1; - bytes = new byte[MAX_SIZE]; - } - } - sb.append(new String(bytes, 0, i + 1, "UTF-8")); - return sb.toString(); - } - /** * This method writes word vectors to the given path. * Please note: this method doesn't load whole vocab/lookupTable into memory, so it's able to process large vocabularies served over network. @@ -355,7 +367,7 @@ public class WordVectorSerializer { val builder = new StringBuilder(); val l = element.getLabel(); - builder.append(encodeB64(l)).append(" "); + builder.append(ReadHelper.encodeB64(l)).append(" "); val vec = lookupTable.vector(element.getLabel()); for (int i = 0; i < vec.length(); i++) { builder.append(vec.getDouble(i)); @@ -518,7 +530,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" "); for (int code : word.getCodes()) { builder.append(code).append(" "); } @@ -536,7 +548,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" "); for (int point : word.getPoints()) { builder.append(point).append(" "); } @@ -554,7 +566,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ") + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ") .append(word.getElementFrequency()).append(" ") .append(vectors.getVocab().docAppearedIn(word.getLabel())); @@ -638,7 +650,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" "); for (int code : word.getCodes()) { builder.append(code).append(" "); } @@ -656,7 +668,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" "); for (int point : word.getPoints()) { builder.append(point).append(" "); } @@ -677,7 +689,7 @@ public class WordVectorSerializer { StringBuilder builder = new StringBuilder(); for (VocabWord word : vectors.getVocab().tokens()) { if (word.isLabel()) - builder.append(encodeB64(word.getLabel())).append("\n"); + builder.append(ReadHelper.encodeB64(word.getLabel())).append("\n"); } IOUtils.write(builder.toString().trim(), zipfile, StandardCharsets.UTF_8); @@ -688,7 +700,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - builder = new StringBuilder(encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency()) + builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency()) .append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel())); writer.println(builder.toString().trim()); @@ -744,7 +756,7 @@ public class WordVectorSerializer { try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { - VocabWord word = vectors.getVocab().tokenFor(decodeB64(line.trim())); + VocabWord word = vectors.getVocab().tokenFor(ReadHelper.decodeB64(line.trim())); if (word != null) { word.markAsLabel(true); } @@ -836,7 +848,7 @@ public class WordVectorSerializer { String line; while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - VocabWord word = w2v.getVocab().tokenFor(decodeB64(split[0])); + VocabWord word = w2v.getVocab().tokenFor(ReadHelper.decodeB64(split[0])); word.setElementFrequency((long) Double.parseDouble(split[1])); word.setSequencesCount((long) Double.parseDouble(split[2])); } @@ -946,7 +958,7 @@ public class WordVectorSerializer { reader = new BufferedReader(new FileReader(h_points)); while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - VocabWord word = vocab.wordFor(decodeB64(split[0])); + VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0])); List points = new ArrayList<>(); for (int i = 1; i < split.length; i++) { points.add(Integer.parseInt(split[i])); @@ -960,7 +972,7 @@ public class WordVectorSerializer { reader = new BufferedReader(new FileReader(h_codes)); while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - VocabWord word = vocab.wordFor(decodeB64(split[0])); + VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0])); List codes = new ArrayList<>(); for (int i = 1; i < split.length; i++) { codes.add(Byte.parseByte(split[i])); @@ -1704,7 +1716,7 @@ public class WordVectorSerializer { if (line.isEmpty()) line = iter.nextLine(); String[] split = line.split(" "); - String word = decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " "); + String word = ReadHelper.decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " "); VocabWord word1 = new VocabWord(1.0, word); word1.setIndex(cache.numWords()); @@ -1994,7 +2006,13 @@ public class WordVectorSerializer { private static final String SYN1_ENTRY = "syn1.bin"; private static final String SYN1_NEG_ENTRY = "syn1neg.bin"; - + /** + * This method saves specified SequenceVectors model to target OutputStream + * + * @param vectors SequenceVectors model + * @param stream Target output stream + * @param + */ public static void writeSequenceVectors(@NonNull SequenceVectors vectors, @NonNull OutputStream stream) throws IOException { @@ -2040,7 +2058,13 @@ public class WordVectorSerializer { } } - + /** + * This method loads SequenceVectors from specified file path + * + * @param path String + * @param readExtendedTables boolean + * @param + */ public static SequenceVectors readSequenceVectors(@NonNull String path, boolean readExtendedTables) throws IOException { @@ -2050,6 +2074,14 @@ public class WordVectorSerializer { return vectors; } + /** + * This method loads SequenceVectors from specified file path + * + * @param file File + * @param readExtendedTables boolean + * @param + */ + public static SequenceVectors readSequenceVectors(@NonNull File file, boolean readExtendedTables) throws IOException { @@ -2058,6 +2090,13 @@ public class WordVectorSerializer { return vectors; } + /** + * This method loads SequenceVectors from specified input stream + * + * @param stream InputStream + * @param readExtendedTables boolean + * @param + */ public static SequenceVectors readSequenceVectors(@NonNull InputStream stream, boolean readExtendedTables) throws IOException { @@ -2381,6 +2420,12 @@ public class WordVectorSerializer { } } + /** + * This method loads Word2Vec model from binary file + * + * @param file File + * @return Word2Vec + */ public static Word2Vec readAsBinary(@NonNull File file) { boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive(); int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency(); @@ -2403,6 +2448,12 @@ public class WordVectorSerializer { } } + /** + * This method loads Word2Vec model from csv file + * + * @param file File + * @return Word2Vec + */ public static Word2Vec readAsCsv(@NonNull File file) { Word2Vec vec; @@ -2491,7 +2542,7 @@ public class WordVectorSerializer { String line; while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0])); + VocabWord word = new VocabWord(Double.valueOf(split[1]), ReadHelper.decodeB64(split[0])); word.setIndex(cnt.getAndIncrement()); word.incrementSequencesCount(Long.valueOf(split[2])); @@ -2669,7 +2720,7 @@ public class WordVectorSerializer { * * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment. * - * @param file File should point to previously saved w2v model + * @param inputStream InputStream should point to previously saved w2v model * @return */ public static WordVectors loadStaticModel(InputStream inputStream) throws IOException { @@ -2685,6 +2736,17 @@ public class WordVectorSerializer { } // TODO: this method needs better name :) + /** + * This method restores previously saved w2v model. File can be in one of the following formats: + * 1) Binary model, either compressed or not. Like well-known Google Model + * 2) Popular CSV word2vec text format + * 3) DL4j compressed format + * + * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment. + * + * @param file File + * @return + */ public static WordVectors loadStaticModel(@NonNull File file) { if (!file.exists() || file.isDirectory()) throw new RuntimeException( @@ -2843,8 +2905,8 @@ public class WordVectorSerializer { throw new RuntimeException(e); } try { - numWords = Integer.parseInt(readString(stream)); - vectorLength = Integer.parseInt(readString(stream)); + numWords = Integer.parseInt(ReadHelper.readString(stream)); + vectorLength = Integer.parseInt(ReadHelper.readString(stream)); } catch (IOException e) { throw new RuntimeException(e); } @@ -2858,13 +2920,13 @@ public class WordVectorSerializer { @Override public Pair next() { try { - String word = readString(stream); + String word = ReadHelper.readString(stream); VocabWord element = new VocabWord(1.0, word); element.setIndex(idxCounter.getAndIncrement()); float[] vector = new float[vectorLength]; for (int i = 0; i < vectorLength; i++) { - vector[i] = readFloat(stream); + vector[i] = ReadHelper.readFloat(stream); } return Pair.makePair(element, vector); @@ -2913,7 +2975,7 @@ public class WordVectorSerializer { String[] split = nextLine.split(" "); - VocabWord word = new VocabWord(1.0, decodeB64(split[0])); + VocabWord word = new VocabWord(1.0, ReadHelper.decodeB64(split[0])); word.setIndex(idxCounter.getAndIncrement()); float[] vector = new float[split.length - 1]; @@ -2937,26 +2999,12 @@ public class WordVectorSerializer { } } - public static String encodeB64(String word) { - try { - return "B64:" + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", ""); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - public static String decodeB64(String word) { - if (word.startsWith("B64:")) { - String arp = word.replaceFirst("B64:", ""); - try { - return new String(Base64.decodeBase64(arp), "UTF-8"); - } catch (Exception e) { - throw new RuntimeException(e); - } - } else - return word; - } - + /** + * This method saves Word2Vec model to output stream + * + * @param word2Vec Word2Vec + * @param stream OutputStream + */ public static void writeWord2Vec(@NonNull Word2Vec word2Vec, @NonNull OutputStream stream) throws IOException { @@ -2968,6 +3016,13 @@ public class WordVectorSerializer { writeSequenceVectors(vectors, stream); } + /** + * This method restores Word2Vec model from file + * + * @param path String + * @param readExtendedTables booleab + * @return Word2Vec + */ public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables) throws IOException { @@ -2976,6 +3031,12 @@ public class WordVectorSerializer { return word2Vec; } + /** + * This method saves table of weights to file + * + * @param weightLookupTable WeightLookupTable + * @param file File + */ public static void writeLookupTable(WeightLookupTable weightLookupTable, @NonNull File file) throws IOException { try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file), @@ -3038,7 +3099,7 @@ public class WordVectorSerializer { headerRead = true; weightLookupTable = new InMemoryLookupTable.Builder().cache(vocabCache).vectorLength(layerSize).build(); } else { - String label = decodeB64(tokens[0]); + String label = ReadHelper.decodeB64(tokens[0]); int freq = Integer.parseInt(tokens[1]); int rows = Integer.parseInt(tokens[2]); int cols = Integer.parseInt(tokens[3]); @@ -3071,6 +3132,13 @@ public class WordVectorSerializer { return weightLookupTable; } + /** + * This method loads Word2Vec model from file + * + * @param file File + * @param readExtendedTables boolean + * @return Word2Vec + */ public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables) throws IOException { @@ -3078,6 +3146,13 @@ public class WordVectorSerializer { return word2Vec; } + /** + * This method loads Word2Vec model from input stream + * + * @param stream InputStream + * @param readExtendedTable boolean + * @return Word2Vec + */ public static Word2Vec readWord2Vec(@NonNull InputStream stream, boolean readExtendedTable) throws IOException { SequenceVectors vectors = readSequenceVectors(stream, readExtendedTable); @@ -3087,7 +3162,13 @@ public class WordVectorSerializer { word2Vec.setModelUtils(vectors.getModelUtils()); return word2Vec; } - + + /** + * This method loads FastText model to file + * + * @param vectors FastText + * @param path File + */ public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException { ObjectOutputStream outputStream = null; try { @@ -3106,6 +3187,11 @@ public class WordVectorSerializer { } } + /** + * This method unloads FastText model from file + * + * @param path File + */ public static FastText readWordVectors(File path) { FastText result = null; try { @@ -3124,6 +3210,13 @@ public class WordVectorSerializer { return result; } + /** + * This method prints memory usage to log + * + * @param numWords + * @param vectorLength + * @param numTables + */ public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) { double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables; @@ -3144,4 +3237,102 @@ public class WordVectorSerializer { OneTimeLogger.info(log, "Projected memory use for model: [{} {}]", String.format("%.2f", value), sfx); } + + /** + * Helper static methods to read data from input stream. + */ + private static class ReadHelper { + /** + * Read a float from a data input stream Credit to: + * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java + * + * @param is + * @return + * @throws IOException + */ + private static float readFloat(InputStream is) throws IOException { + byte[] bytes = new byte[4]; + is.read(bytes); + return getFloat(bytes); + } + + /** + * Read a string from a data input stream Credit to: + * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java + * + * @param b + * @return + * @throws IOException + */ + private static float getFloat(byte[] b) { + int accum = 0; + accum = accum | (b[0] & 0xff) << 0; + accum = accum | (b[1] & 0xff) << 8; + accum = accum | (b[2] & 0xff) << 16; + accum = accum | (b[3] & 0xff) << 24; + return Float.intBitsToFloat(accum); + } + + /** + * Read a string from a data input stream Credit to: + * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java + * + * @param dis + * @return + * @throws IOException + */ + private static String readString(DataInputStream dis) throws IOException { + byte[] bytes = new byte[MAX_SIZE]; + byte b = dis.readByte(); + int i = -1; + StringBuilder sb = new StringBuilder(); + while (b != 32 && b != 10) { + i++; + bytes[i] = b; + b = dis.readByte(); + if (i == 49) { + sb.append(new String(bytes, "UTF-8")); + i = -1; + bytes = new byte[MAX_SIZE]; + } + } + sb.append(new String(bytes, 0, i + 1, "UTF-8")); + return sb.toString(); + } + + private static final String B64 = "B64:"; + + /** + * Encode input string + * + * @param word String + * @return String + */ + private static String encodeB64(String word) { + try { + return B64 + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", ""); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Encode input string + * + * @param word String + * @return String + */ + + private static String decodeB64(String word) { + if (word.startsWith(B64)) { + String arp = word.replaceFirst(B64, ""); + try { + return new String(Base64.decodeBase64(arp), "UTF-8"); + } catch (Exception e) { + throw new RuntimeException(e); + } + } else + return word; + } + } }