diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java
index 27d49d5f5..7d6c0f559 100755
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java
@@ -856,15 +856,26 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
@Test
public void testFastText() {
-
- File[] files = {fastTextRaw, fastTextZip, fastTextGzip};
+ File[] files = { fastTextRaw, fastTextZip, fastTextGzip };
for (File file : files) {
try {
Word2Vec word2Vec = WordVectorSerializer.readAsCsv(file);
- assertEquals(99, word2Vec.getVocab().numWords());
+ assertEquals(99, word2Vec.getVocab().numWords());
+ } catch (Exception readCsvException) {
+ fail("Failure for input file " + file.getAbsolutePath() + " " + readCsvException.getMessage());
+ }
+ }
+ }
- } catch (Exception e) {
- fail("Failure for input file " + file.getAbsolutePath() + " " + e.getMessage());
+ @Test
+ public void testFastText_readWord2VecModel() {
+ File[] files = { fastTextRaw, fastTextZip, fastTextGzip };
+ for (File file : files) {
+ try {
+ Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel(file);
+ assertEquals(99, word2Vec.getVocab().numWords());
+ } catch (Exception readCsvException) {
+ fail("Failure for input file " + file.getAbsolutePath() + " " + readCsvException.getMessage());
}
}
}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
index 8f0003728..a77bdf0de 100755
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -16,14 +17,45 @@ package org.deeplearning4j.models.embeddings.loader;
-import lombok.*;
-import lombok.extern.slf4j.Slf4j;
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.ByteArrayInputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.FileReader;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.io.UnsupportedEncodingException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipFile;
+import java.util.zip.ZipInputStream;
+import java.util.zip.ZipOutputStream;
+
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.compress.compressors.gzip.GzipUtils;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.apache.commons.io.output.CloseShieldOutputStream;
+import org.deeplearning4j.common.util.DL4JFileUtils;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
@@ -50,26 +82,25 @@ import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
-import org.deeplearning4j.common.util.DL4JFileUtils;
+import org.nd4j.common.primitives.Pair;
+import org.nd4j.common.util.OneTimeLogger;
import org.nd4j.compression.impl.NoOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
-import org.nd4j.common.primitives.Pair;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.storage.CompressedRamStorage;
-import org.nd4j.common.util.OneTimeLogger;
-import java.io.*;
-import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.zip.*;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import lombok.extern.slf4j.Slf4j;
+import lombok.val;
/**
* This is utility class, providing various methods for WordVectors serialization
*
@@ -85,14 +116,17 @@ import java.util.zip.*;
* {@link #writeWord2VecModel(Word2Vec, OutputStream)}
*
*
@@ -2389,106 +2432,196 @@ public class WordVectorSerializer {
- * Please note: Only weights will be loaded by this method.
- *
- * @param file
- * @return
- */
- public static Word2Vec readWord2VecModel(@NonNull File file) {
- return readWord2VecModel(file, false);
- }
-
- /**
* This method
* 1) Binary model, either compressed or not. Like well-known Google Model
* 2) Popular CSV word2vec text format
* 3) DL4j compressed format
*
- * Please note: if extended data isn't available, only weights will be loaded instead.
+ * Please note: Only weights will be loaded by this method.
- *
- * @param path
- * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
+ * @param path path to model file
+ * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
* @return
*/
public static Word2Vec readWord2VecModel(String path, boolean extendedModel) {
return readWord2VecModel(new File(path), extendedModel);
}
- public static Word2Vec readAsBinaryNoLineBreaks(@NonNull File file) {
+ /**
+ * This method
+ * 1) Binary model, either compressed or not. Like well-known Google Model
+ * 2) Popular CSV word2vec text format
+ * 3) DL4j compressed format
+
+ * Please note: Only weights will be loaded by this method.
+ *
+ * @param file
+ * @return
+ */
+ public static Word2Vec readWord2VecModel(File file) {
+ return readWord2VecModel(file, false);
+ }
+
+ /**
+ * This method
+ * 1) Binary model, either compressed or not. Like well-known Google Model
+ * 2) Popular CSV word2vec text format
+ * 3) DL4j compressed format
+
+ * Please note: if extended data isn't available, only weights will be loaded instead.
+ *
+ * @param file model file
+ * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
+ * @return word2vec model
+ */
+ public static Word2Vec readWord2VecModel(File file, boolean extendedModel) {
+ if (!file.exists() || !file.isFile()) {
+ throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
+ }
+ boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
+ if (originalPeriodic) {
+ Nd4j.getMemoryManager().togglePeriodicGc(false);
+ }
+ Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
+
+ try {
+ return readWord2Vec(file, extendedModel);
+ } catch (Exception readSequenceVectors) {
+ try {
+ return extendedModel
+ ? readAsExtendedModel(file)
+ : readAsSimplifiedModel(file);
+ } catch (Exception loadFromFileException) {
+ try {
+ return readAsCsv(file);
+ } catch (Exception readCsvException) {
+ try {
+ return readAsBinary(file);
+ } catch (Exception readBinaryException) {
+ try {
+ return readAsBinaryNoLineBreaks(file);
+ } catch (Exception readModelException) {
+ log.error("Unable to guess input file format", readModelException);
+ throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
+ }
+ }
+ }
+ }
+ }
+ }
+
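For orientation, a minimal caller-side sketch of the new entry point above; the path is hypothetical and the vocabulary check mirrors the updated test:

import java.io.File;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;

public class ReadAnyFormatSketch {
    public static void main(String[] args) {
        // Hypothetical path: the file may be a Google-style binary, a CSV/text
        // model or a DL4J archive; readWord2VecModel tries each reader in turn.
        File modelFile = new File("/tmp/word_vectors.bin");
        Word2Vec vec = WordVectorSerializer.readWord2VecModel(modelFile, false);
        System.out.println("Vocabulary size: " + vec.getVocab().numWords());
    }
}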
+ public static Word2Vec readAsBinaryNoLineBreaks(@NonNull File file) {
+ try (InputStream inputStream = fileStream(file)) {
+ return readAsBinaryNoLineBreaks(inputStream);
+ } catch (IOException readCsvException) {
+ throw new RuntimeException(readCsvException);
+ }
+ }
+
+ public static Word2Vec readAsBinaryNoLineBreaks(@NonNull InputStream inputStream) {
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
- Word2Vec vec;
// try to load without linebreaks
try {
- if (originalPeriodic)
+ if (originalPeriodic) {
Nd4j.getMemoryManager().togglePeriodicGc(true);
+ }
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
- vec = readBinaryModel(file, false, false);
- return vec;
- } catch (Exception ez) {
- throw new RuntimeException(
- "Unable to guess input file format. Please use corresponding loader directly");
+ return readBinaryModel(inputStream, false, false);
+ } catch (Exception readModelException) {
+ log.error("Cannot read binary model", readModelException);
+ throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
+ }
+ }
+
+ public static Word2Vec readAsBinary(@NonNull File file) {
+ try (InputStream inputStream = fileStream(file)) {
+ return readAsBinary(inputStream);
+ } catch (IOException readCsvException) {
+ throw new RuntimeException(readCsvException);
}
}
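When the input format is already known, the dedicated reader above can be called directly instead of going through the guessing chain, which also gives a clearer failure than the generic "unable to guess input file format" message. A sketch, assuming an uncompressed Google-style binary at a made-up path:

import java.io.File;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;

public class ReadBinarySketch {
    public static void main(String[] args) {
        // Assumed location of an uncompressed binary model (e.g. the well-known
        // Google News vectors); the format is known, so no guessing is needed.
        File binaryModel = new File("/data/GoogleNews-vectors-negative300.bin");
        Word2Vec vec = WordVectorSerializer.readAsBinary(binaryModel);
        System.out.println("Words loaded: " + vec.getVocab().numWords());
    }
}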
/**
- * This method loads Word2Vec model from binary file
+ * This method loads Word2Vec model from binary input stream.
*
- * @param file File
- * @return Word2Vec
+ * @param inputStream binary input stream
+ * @return Word2Vec
*/
- public static Word2Vec readAsBinary(@NonNull File file) {
+ public static Word2Vec readAsBinary(@NonNull InputStream inputStream) {
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
- Word2Vec vec;
-
// we fallback to trying binary model instead
try {
log.debug("Trying binary model restoration...");
- if (originalPeriodic)
+ if (originalPeriodic) {
Nd4j.getMemoryManager().togglePeriodicGc(true);
+ }
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
- vec = readBinaryModel(file, true, false);
- return vec;
- } catch (Exception ey) {
- throw new RuntimeException(ey);
+ return readBinaryModel(inputStream, true, false);
+ } catch (Exception readModelException) {
+ throw new RuntimeException(readModelException);
+ }
+ }
+
+ public static Word2Vec readAsCsv(@NonNull File file) {
+ try (InputStream inputStream = fileStream(file)) {
+ return readAsCsv(inputStream);
+ } catch (IOException readCsvException) {
+ throw new RuntimeException(readCsvException);
}
}
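The File overload above is now a thin wrapper; the actual parsing happens in the stream-based reader that follows, so a caller can hand it any InputStream, for example after doing its own decompression. A sketch under that assumption, with a made-up path:

import java.io.FileInputStream;
import java.io.InputStream;
import java.util.zip.GZIPInputStream;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;

public class ReadCsvFromStreamSketch {
    public static void main(String[] args) throws Exception {
        // Made-up path to a gzipped text model; decompression is done by the
        // caller, readAsCsv only ever sees the plain text stream.
        try (InputStream in = new GZIPInputStream(new FileInputStream("/tmp/words.csv.gz"))) {
            Word2Vec vec = WordVectorSerializer.readAsCsv(in);
            System.out.println("Words in model: " + vec.getVocab().numWords());
        }
    }
}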
/**
* This method loads Word2Vec model from csv file
*
- * @param file File
- * @return Word2Vec
+ * @param inputStream input stream
+ * @return Word2Vec model
*/
- public static Word2Vec readAsCsv(@NonNull File file) {
-
- Word2Vec vec;
+ public static Word2Vec readAsCsv(@NonNull InputStream inputStream) {
VectorsConfiguration configuration = new VectorsConfiguration();
// let's try to load this file as csv file
try {
log.debug("Trying CSV model restoration...");
- Pair
- * Please note: if extended data isn't available, only weights will be loaded instead.
- *
- * @param file
- * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
- * @return
- */
- public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedModel) {
-
- if (!file.exists() || !file.isFile())
- throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
-
- Word2Vec vec = null;
-
- int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
- boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
- if (originalPeriodic)
- Nd4j.getMemoryManager().togglePeriodicGc(false);
- Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
-
- // try to load zip format
- try {
- vec = readWord2Vec(file, extendedModel);
- return vec;
- } catch (Exception e) {
- // let's try to load this file as csv file
- try {
- if (extendedModel) {
- vec = readAsExtendedModel(file);
- return vec;
- } else {
- vec = readAsSimplifiedModel(file);
- return vec;
- }
- } catch (Exception ex) {
- try {
- vec = readAsCsv(file);
- return vec;
- } catch (Exception exc) {
- try {
- vec = readAsBinary(file);
- return vec;
- } catch (Exception exce) {
- try {
- vec = readAsBinaryNoLineBreaks(file);
- return vec;
-
- } catch (Exception excep) {
- throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
- }
- }
- }
- }
- }
- }
-
protected static TokenizerFactory getTokenizerFactory(VectorsConfiguration configuration) {
if (configuration == null)
return null;
@@ -3019,16 +3091,13 @@ public class WordVectorSerializer {
/**
* This method restores Word2Vec model from file
*
- * @param path String
- * @param readExtendedTables booleab
+ * @param path
+ * @param readExtendedTables
* @return Word2Vec
*/
- public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables)
- throws IOException {
-
+ public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables) {
File file = new File(path);
- Word2Vec word2Vec = readWord2Vec(file, readExtendedTables);
- return word2Vec;
+ return readWord2Vec(file, readExtendedTables);
}
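Because the String overload no longer declares IOException, callers do not need a checked-exception handler any more; I/O problems are rethrown unchecked by the File-based overload further down. A small usage sketch with an assumed archive path:

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;

public class ReadWord2VecSketch {
    public static void main(String[] args) {
        // Assumed path to a model previously saved by the DL4J serializer.
        // No throws clause is needed after this change: I/O failures surface
        // as RuntimeException from the File-based overload below.
        Word2Vec vec = WordVectorSerializer.readWord2Vec("/tmp/word2vec_model.zip", false);
        System.out.println("Restored vocabulary: " + vec.getVocab().numWords());
    }
}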
/**
@@ -3139,11 +3208,12 @@ public class WordVectorSerializer {
* @param readExtendedTables boolean
* @return Word2Vec
*/
- public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables)
- throws IOException {
-
- Word2Vec word2Vec = readWord2Vec(new FileInputStream(file), readExtendedTables);
- return word2Vec;
+ public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables) {
+ try (InputStream inputStream = fileStream(file)) {
+ return readWord2Vec(inputStream, readExtendedTables);
+ } catch (Exception readSequenceVectors) {
+ throw new RuntimeException(readSequenceVectors);
+ }
}
/**
@@ -3153,13 +3223,19 @@ public class WordVectorSerializer {
* @param readExtendedTable boolean
* @return Word2Vec
*/
- public static Word2Vec readWord2Vec(@NonNull InputStream stream,
- boolean readExtendedTable) throws IOException {
+ public static Word2Vec readWord2Vec(
+ @NonNull InputStream stream,
+ boolean readExtendedTable) throws IOException {
SequenceVectors