FEATURE: change API of WordVectorSerializer. Add the possibility to read models from InputStreams, not only from files

Signed-off-by: hosuaby <alexei.klenin@gmail.com>
master
hosuaby 2020-01-31 16:05:45 +01:00
parent 88ef784b7c
commit dab75fa50b
5 changed files with 464 additions and 354 deletions

View File

@ -856,15 +856,26 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
@Test @Test
public void testFastText() { public void testFastText() {
File[] files = { fastTextRaw, fastTextZip, fastTextGzip };
File[] files = {fastTextRaw, fastTextZip, fastTextGzip};
for (File file : files) { for (File file : files) {
try { try {
Word2Vec word2Vec = WordVectorSerializer.readAsCsv(file); Word2Vec word2Vec = WordVectorSerializer.readAsCsv(file);
assertEquals(99, word2Vec.getVocab().numWords()); assertEquals(99, word2Vec.getVocab().numWords());
} catch (Exception readCsvException) {
fail("Failure for input file " + file.getAbsolutePath() + " " + readCsvException.getMessage());
}
}
}
} catch (Exception e) { @Test
fail("Failure for input file " + file.getAbsolutePath() + " " + e.getMessage()); public void testFastText_readWord2VecModel() {
File[] files = { fastTextRaw, fastTextZip, fastTextGzip };
for (File file : files) {
try {
Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel(file);
assertEquals(99, word2Vec.getVocab().numWords());
} catch (Exception readCsvException) {
fail("Failure for input file " + file.getAbsolutePath() + " " + readCsvException.getMessage());
} }
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,14 +17,45 @@
package org.deeplearning4j.models.embeddings.loader; package org.deeplearning4j.models.embeddings.loader;
import lombok.*; import java.io.BufferedInputStream;
import lombok.extern.slf4j.Slf4j; import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.GZIPInputStream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
import org.apache.commons.codec.binary.Base64; import org.apache.commons.codec.binary.Base64;
import org.apache.commons.compress.compressors.gzip.GzipUtils; import org.apache.commons.compress.compressors.gzip.GzipUtils;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator; import org.apache.commons.io.LineIterator;
import org.apache.commons.io.output.CloseShieldOutputStream; import org.apache.commons.io.output.CloseShieldOutputStream;
import org.deeplearning4j.common.util.DL4JFileUtils;
import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
@ -50,26 +82,25 @@ import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess; import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.common.util.DL4JFileUtils; import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.OneTimeLogger;
import org.nd4j.compression.impl.NoOp; import org.nd4j.compression.impl.NoOp;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.common.primitives.Pair;
import org.nd4j.shade.jackson.databind.DeserializationFeature; import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature; import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature; import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.storage.CompressedRamStorage; import org.nd4j.storage.CompressedRamStorage;
import org.nd4j.common.util.OneTimeLogger;
import java.io.*; import lombok.AllArgsConstructor;
import java.nio.charset.StandardCharsets; import lombok.Data;
import java.util.ArrayList; import lombok.NoArgsConstructor;
import java.util.List; import lombok.NonNull;
import java.util.concurrent.atomic.AtomicInteger; import lombok.extern.slf4j.Slf4j;
import java.util.zip.*; import lombok.val;
/** /**
* This is utility class, providing various methods for WordVectors serialization * This is utility class, providing various methods for WordVectors serialization
@ -85,14 +116,17 @@ import java.util.zip.*;
* {@link #writeWord2VecModel(Word2Vec, OutputStream)} * {@link #writeWord2VecModel(Word2Vec, OutputStream)}
* *
* <li>Deserializers for Word2Vec:</li> * <li>Deserializers for Word2Vec:</li>
* {@link #readWord2VecModel(File)}
* {@link #readWord2VecModel(String)} * {@link #readWord2VecModel(String)}
* {@link #readWord2VecModel(File, boolean)}
* {@link #readWord2VecModel(String, boolean)} * {@link #readWord2VecModel(String, boolean)}
* {@link #readWord2VecModel(File)}
* {@link #readWord2VecModel(File, boolean)}
* {@link #readAsBinaryNoLineBreaks(File)} * {@link #readAsBinaryNoLineBreaks(File)}
* {@link #readAsBinaryNoLineBreaks(InputStream)}
* {@link #readAsBinary(File)} * {@link #readAsBinary(File)}
* {@link #readAsBinary(InputStream)}
* {@link #readAsCsv(File)} * {@link #readAsCsv(File)}
* {@link #readBinaryModel(File, boolean, boolean)} * {@link #readAsCsv(InputStream)}
* {@link #readBinaryModel(InputStream, boolean, boolean)}
* {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)} * {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)}
* {@link #readWord2Vec(String, boolean)} * {@link #readWord2Vec(String, boolean)}
* {@link #readWord2Vec(File, boolean)} * {@link #readWord2Vec(File, boolean)}
@ -117,6 +151,7 @@ import java.util.zip.*;
* {@link #fromTableAndVocab(WeightLookupTable, VocabCache)} * {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
* {@link #fromPair(Pair)} * {@link #fromPair(Pair)}
* {@link #loadTxt(File)} * {@link #loadTxt(File)}
* {@link #loadTxt(InputStream)}
* *
* <li>Serializers to tSNE format</li> * <li>Serializers to tSNE format</li>
* {@link #writeTsneFormat(Glove, INDArray, File)} * {@link #writeTsneFormat(Glove, INDArray, File)}
@ -151,6 +186,7 @@ import java.util.zip.*;
* @author Adam Gibson * @author Adam Gibson
* @author raver119 * @author raver119
* @author alexander@skymind.io * @author alexander@skymind.io
* @author Alexei KLENIN
*/ */
@Slf4j @Slf4j
public class WordVectorSerializer { public class WordVectorSerializer {
@ -215,18 +251,22 @@ public class WordVectorSerializer {
}*/ }*/
/** /**
* Read a binary word2vec file. * Read a binary word2vec from input stream.
*
* @param inputStream input stream to read
* @param linebreaks if true, the reader expects each word/vector to be in a separate line, terminated
* by a line break
* @param normalize
* *
* @param modelFile the File to read
* @param linebreaks if true, the reader expects each word/vector to be in a separate line, terminated
* by a line break
* @return a {@link Word2Vec model} * @return a {@link Word2Vec model}
* @throws NumberFormatException * @throws NumberFormatException
* @throws IOException * @throws IOException
* @throws FileNotFoundException * @throws FileNotFoundException
*/ */
public static Word2Vec readBinaryModel(File modelFile, boolean linebreaks, boolean normalize) public static Word2Vec readBinaryModel(
throws NumberFormatException, IOException { InputStream inputStream,
boolean linebreaks,
boolean normalize) throws NumberFormatException, IOException {
InMemoryLookupTable<VocabWord> lookupTable; InMemoryLookupTable<VocabWord> lookupTable;
VocabCache<VocabWord> cache; VocabCache<VocabWord> cache;
INDArray syn0; INDArray syn0;
@ -240,9 +280,7 @@ public class WordVectorSerializer {
Nd4j.getMemoryManager().setOccasionalGcFrequency(50000); Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName()) try (DataInputStream dis = new DataInputStream(inputStream)) {
? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile));
DataInputStream dis = new DataInputStream(bis)) {
words = Integer.parseInt(ReadHelper.readString(dis)); words = Integer.parseInt(ReadHelper.readString(dis));
size = Integer.parseInt(ReadHelper.readString(dis)); size = Integer.parseInt(ReadHelper.readString(dis));
syn0 = Nd4j.create(words, size); syn0 = Nd4j.create(words, size);
@ -250,23 +288,26 @@ public class WordVectorSerializer {
printOutProjectedMemoryUse(words, size, 1); printOutProjectedMemoryUse(words, size, 1);
lookupTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>().cache(cache) lookupTable = new InMemoryLookupTable.Builder<VocabWord>()
.useHierarchicSoftmax(false).vectorLength(size).build(); .cache(cache)
.useHierarchicSoftmax(false)
.vectorLength(size)
.build();
int cnt = 0;
String word; String word;
float[] vector = new float[size]; float[] vector = new float[size];
for (int i = 0; i < words; i++) { for (int i = 0; i < words; i++) {
word = ReadHelper.readString(dis); word = ReadHelper.readString(dis);
log.trace("Loading " + word + " with word " + i); log.trace("Loading {} with word {}", word, i);
for (int j = 0; j < size; j++) { for (int j = 0; j < size; j++) {
vector[j] = ReadHelper.readFloat(dis); vector[j] = ReadHelper.readFloat(dis);
} }
if (cache.containsWord(word)) if (cache.containsWord(word)) {
throw new ND4JIllegalStateException("Tried to add existing word. Probably time to switch linebreaks mode?"); throw new ND4JIllegalStateException(
"Tried to add existing word. Probably time to switch linebreaks mode?");
}
syn0.putRow(i, normalize ? Transforms.unitVec(Nd4j.create(vector)) : Nd4j.create(vector)); syn0.putRow(i, normalize ? Transforms.unitVec(Nd4j.create(vector)) : Nd4j.create(vector));
@ -285,25 +326,31 @@ public class WordVectorSerializer {
Nd4j.getMemoryManager().invokeGcOccasionally(); Nd4j.getMemoryManager().invokeGcOccasionally();
} }
} finally { } finally {
if (originalPeriodic) if (originalPeriodic) {
Nd4j.getMemoryManager().togglePeriodicGc(true); Nd4j.getMemoryManager().togglePeriodicGc(true);
}
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq); Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
} }
lookupTable.setSyn0(syn0); lookupTable.setSyn0(syn0);
Word2Vec ret = new Word2Vec
Word2Vec ret = new Word2Vec.Builder().useHierarchicSoftmax(false).resetModel(false).layerSize(syn0.columns()) .Builder()
.allowParallelTokenization(true).elementsLearningAlgorithm(new SkipGram<VocabWord>()) .useHierarchicSoftmax(false)
.learningRate(0.025).windowSize(5).workers(1).build(); .resetModel(false)
.layerSize(syn0.columns())
.allowParallelTokenization(true)
.elementsLearningAlgorithm(new SkipGram<VocabWord>())
.learningRate(0.025)
.windowSize(5)
.workers(1)
.build();
ret.setVocab(cache); ret.setVocab(cache);
ret.setLookupTable(lookupTable); ret.setLookupTable(lookupTable);
return ret; return ret;
} }
/** /**
@ -927,7 +974,7 @@ public class WordVectorSerializer {
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes, public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes,
@NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException { @NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
// first we load syn0 // first we load syn0
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(vectors); Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors));
InMemoryLookupTable lookupTable = pair.getFirst(); InMemoryLookupTable lookupTable = pair.getFirst();
lookupTable.setNegative(configuration.getNegative()); lookupTable.setNegative(configuration.getNegative());
if (configuration.getNegative() > 0) if (configuration.getNegative() > 0)
@ -1604,160 +1651,172 @@ public class WordVectorSerializer {
* @param vectorsFile the path of the file to load\ * @param vectorsFile the path of the file to load\
* @return * @return
* @throws FileNotFoundException if the file does not exist * @throws FileNotFoundException if the file does not exist
* @deprecated Use {@link #loadTxt(File)} * @deprecated Use {@link #loadTxt(InputStream)}
*/ */
@Deprecated @Deprecated
public static WordVectors loadTxtVectors(File vectorsFile) public static WordVectors loadTxtVectors(File vectorsFile) throws IOException {
throws IOException { FileInputStream fileInputStream = new FileInputStream(vectorsFile);
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(vectorsFile); Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(fileInputStream);
return fromPair(pair); return fromPair(pair);
} }
/**
 * Opens a buffered input stream over the given file, choosing a decompressor from the file name.
 * <p>
 * Names ending in {@code .zip} are opened through {@link #decompressZip(File)}, names recognised by
 * {@link GzipUtils#isCompressedFilename(String)} are wrapped in a {@link GZIPInputStream}, and every
 * other file is read as-is. The caller owns the returned stream and must close it.
 *
 * @param file file to open
 * @return a {@link BufferedInputStream} over the (possibly decompressed) file content
 * @throws IOException if the file cannot be opened, the gzip header is invalid, or the zip archive
 *                     yields no entry to read
 */
static InputStream fileStream(@NonNull File file) throws IOException {
    String fileName = file.getName();

    InputStream inputStream;
    if (fileName.endsWith(".zip")) {
        inputStream = decompressZip(file);
        if (inputStream == null) {
            // decompressZip returns null when the archive has no entries; fail here with a clear
            // message instead of wrapping null and failing later on the first read
            throw new IOException("Zip archive " + file + " contains no entries");
        }
    } else if (GzipUtils.isCompressedFilename(fileName)) {
        FileInputStream fis = new FileInputStream(file);
        try {
            inputStream = new GZIPInputStream(fis);
        } catch (IOException | RuntimeException gzipHeaderException) {
            // GZIPInputStream reads the gzip header in its constructor; if that fails nobody else
            // holds a reference to the raw file stream, so it has to be closed here
            try {
                fis.close();
            } catch (IOException closeException) {
                gzipHeaderException.addSuppressed(closeException);
            }
            throw gzipHeaderException;
        }
    } else {
        inputStream = new FileInputStream(file);
    }

    return new BufferedInputStream(inputStream);
}
private static InputStream decompressZip(File modelFile) throws IOException { private static InputStream decompressZip(File modelFile) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ZipFile zipFile = new ZipFile(modelFile); ZipFile zipFile = new ZipFile(modelFile);
InputStream inputStream = null; InputStream inputStream = null;
try (ZipInputStream zipStream = new ZipInputStream(new BufferedInputStream(new FileInputStream(modelFile)))) { try (FileInputStream fis = new FileInputStream(modelFile);
BufferedInputStream bis = new BufferedInputStream(fis);
ZipEntry entry = null; ZipInputStream zipStream = new ZipInputStream(bis)) {
ZipEntry entry;
if ((entry = zipStream.getNextEntry()) != null) { if ((entry = zipStream.getNextEntry()) != null) {
inputStream = zipFile.getInputStream(entry); inputStream = zipFile.getInputStream(entry);
} }
if (zipStream.getNextEntry() != null) { if (zipStream.getNextEntry() != null) {
throw new RuntimeException("Zip archive " + modelFile + " contains more than 1 file"); throw new RuntimeException("Zip archive " + modelFile + " contains more than 1 file");
} }
} }
return inputStream; return inputStream;
} }
private static BufferedReader createReader(File vectorsFile) throws IOException { public static Pair<InMemoryLookupTable, VocabCache> loadTxt(@NonNull File file) {
InputStreamReader inputStreamReader; try (InputStream inputStream = fileStream(file)) {
try { return loadTxt(inputStream);
inputStreamReader = new InputStreamReader(decompressZip(vectorsFile)); } catch (IOException readTestException) {
} catch (IOException e) { throw new RuntimeException(readTestException);
inputStreamReader = new InputStreamReader(GzipUtils.isCompressedFilename(vectorsFile.getName())
? new GZIPInputStream(new FileInputStream(vectorsFile))
: new FileInputStream(vectorsFile), "UTF-8");
} }
BufferedReader reader = new BufferedReader(inputStreamReader);
return reader;
} }
/** /**
* Loads an in memory cache from the given path (sets syn0 and the vocab) * Loads an in memory cache from the given input stream (sets syn0 and the vocab).
* *
* @param vectorsFile the path of the file to load * @param inputStream input stream
* @return a Pair holding the lookup table and the vocab cache. * @return a {@link Pair} holding the lookup table and the vocab cache.
* @throws FileNotFoundException if the input file does not exist
*/ */
public static Pair<InMemoryLookupTable, VocabCache> loadTxt(File vectorsFile) public static Pair<InMemoryLookupTable, VocabCache> loadTxt(@NonNull InputStream inputStream) {
throws IOException, UnsupportedEncodingException { AbstractCache<VocabWord> cache = new AbstractCache<>();
LineIterator lines = null;
AbstractCache cache = new AbstractCache<>(); try (InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
BufferedReader reader = createReader(vectorsFile); BufferedReader reader = new BufferedReader(inputStreamReader)) {
LineIterator iter = IOUtils.lineIterator(reader); lines = IOUtils.lineIterator(reader);
String line = null;
boolean hasHeader = false;
if (iter.hasNext()) {
line = iter.nextLine(); // skip header line
//look for spaces
if (!line.contains(" ")) {
log.debug("Skipping first line");
hasHeader = true;
} else {
// we should check for something that looks like proper word vectors here. i.e: 1 word at the 0 position, and bunch of floats further
String[] split = line.split(" ");
try {
long[] header = new long[split.length];
for (int x = 0; x < split.length; x++) {
header[x] = Long.parseLong(split[x]);
}
if (split.length < 4)
hasHeader = true;
// now we know, if that's all ints - it's just a header
// [0] - number of words
// [1] - vectorSize
// [2] - number of documents <-- DL4j-only value
if (split.length == 3)
cache.incrementTotalDocCount(header[2]);
printOutProjectedMemoryUse(header[0], (int) header[1], 1); String line = null;
boolean hasHeader = false;
hasHeader = true; /* Check if first line is a header */
if (lines.hasNext()) {
line = lines.nextLine();
hasHeader = isHeader(line, cache);
}
try { if (hasHeader) {
reader.close(); log.debug("First line is a header");
} catch (Exception ex) { line = lines.nextLine();
} }
} catch (Exception e) {
// if any conversion exception hits - that'll be considered header
hasHeader = false;
List<INDArray> arrays = new ArrayList<>();
long[] vShape = new long[]{ 1, -1 };
do {
String[] tokens = line.split(" ");
String word = ReadHelper.decodeB64(tokens[0]);
VocabWord vocabWord = new VocabWord(1.0, word);
vocabWord.setIndex(cache.numWords());
cache.addToken(vocabWord);
cache.addWordToIndex(vocabWord.getIndex(), word);
cache.putVocabWord(word);
float[] vector = new float[tokens.length - 1];
for (int i = 1; i < tokens.length; i++) {
vector[i - 1] = Float.parseFloat(tokens[i]);
} }
vShape[1] = vector.length;
INDArray row = Nd4j.create(vector, vShape);
arrays.add(row);
line = lines.hasNext() ? lines.next() : null;
} while (line != null);
INDArray syn = Nd4j.vstack(arrays);
InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
.Builder<VocabWord>()
.vectorLength(arrays.get(0).columns())
.useAdaGrad(false)
.cache(cache)
.useHierarchicSoftmax(false)
.build();
lookupTable.setSyn0(syn);
return new Pair<>((InMemoryLookupTable) lookupTable, (VocabCache) cache);
} catch (IOException readeTextStreamException) {
throw new RuntimeException(readeTextStreamException);
} finally {
if (lines != null) {
lines.close();
} }
} }
}
//reposition buffer to be one line ahead static boolean isHeader(String line, AbstractCache cache) {
if (hasHeader) { if (!line.contains(" ")) {
line = ""; return true;
iter.close(); } else {
//reader = new BufferedReader(new FileReader(vectorsFile));
reader = createReader(vectorsFile);
iter = IOUtils.lineIterator(reader);
iter.nextLine();
}
List<INDArray> arrays = new ArrayList<>(); /* We should check for something that looks like proper word vectors here. i.e: 1 word at the 0
long[] vShape = new long[]{1, -1}; * position, and bunch of floats further */
while (iter.hasNext()) { String[] headers = line.split(" ");
if (line.isEmpty())
line = iter.nextLine();
String[] split = line.split(" ");
String word = ReadHelper.decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
VocabWord word1 = new VocabWord(1.0, word);
word1.setIndex(cache.numWords()); try {
long[] header = new long[headers.length];
for (int x = 0; x < headers.length; x++) {
header[x] = Long.parseLong(headers[x]);
}
cache.addToken(word1); /* Now we know, if that's all ints - it's just a header
* [0] - number of words
* [1] - vectorLength
* [2] - number of documents <-- DL4j-only value
*/
if (headers.length == 3) {
long numberOfDocuments = header[2];
cache.incrementTotalDocCount(numberOfDocuments);
}
cache.addWordToIndex(word1.getIndex(), word); long numWords = header[0];
int vectorLength = (int) header[1];
printOutProjectedMemoryUse(numWords, vectorLength, 1);
cache.putVocabWord(word); return true;
} catch (Exception notHeaderException) {
float[] vector = new float[split.length - 1]; // if any conversion exception hits - that'll be considered header
return false;
for (int i = 1; i < split.length; i++) {
vector[i - 1] = Float.parseFloat(split[i]);
} }
vShape[1] = vector.length;
INDArray row = Nd4j.create(vector, vShape);
arrays.add(row);
// workaround for skipped first row
line = "";
} }
INDArray syn = Nd4j.vstack(arrays);
InMemoryLookupTable lookupTable =
(InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(arrays.get(0).columns())
.useAdaGrad(false).cache(cache).useHierarchicSoftmax(false).build();
lookupTable.setSyn0(syn);
iter.close();
try {
reader.close();
} catch (Exception e) {
}
return new Pair<>(lookupTable, (VocabCache) cache);
} }
/** /**
@ -2352,22 +2411,6 @@ public class WordVectorSerializer {
} }
} }
/**
* This method
* 1) Binary model, either compressed or not. Like well-known Google Model
* 2) Popular CSV word2vec text format
* 3) DL4j compressed format
* <p>
* Please note: Only weights will be loaded by this method.
*
* @param file
* @return
*/
public static Word2Vec readWord2VecModel(@NonNull File file) {
return readWord2VecModel(file, false);
}
/** /**
* This method * This method
* 1) Binary model, either compressed or not. Like well-known Google Model * 1) Binary model, either compressed or not. Like well-known Google Model
@ -2389,106 +2432,196 @@ public class WordVectorSerializer {
* 2) Popular CSV word2vec text format * 2) Popular CSV word2vec text format
* 3) DL4j compressed format * 3) DL4j compressed format
* <p> * <p>
* Please note: if extended data isn't available, only weights will be loaded instead. * Please note: Only weights will be loaded by this method.
* *
* @param path * @param path path to model file
* @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
* @return * @return
*/ */
public static Word2Vec readWord2VecModel(String path, boolean extendedModel) { public static Word2Vec readWord2VecModel(String path, boolean extendedModel) {
return readWord2VecModel(new File(path), extendedModel); return readWord2VecModel(new File(path), extendedModel);
} }
public static Word2Vec readAsBinaryNoLineBreaks(@NonNull File file) { /**
* This method
* 1) Binary model, either compressed or not. Like well-known Google Model
* 2) Popular CSV word2vec text format
* 3) DL4j compressed format
* <p>
* Please note: Only weights will be loaded by this method.
*
* @param file
* @return
*/
public static Word2Vec readWord2VecModel(File file) {
    // weights-only loading: delegate with the extended-model flag switched off
    final boolean extendedModel = false;
    return readWord2VecModel(file, extendedModel);
}
/**
 * Loads a Word2Vec model from a file whose format is not known in advance.
 * <p>
 * The loaders are tried in this order, and the first one that succeeds wins:
 * <ol>
 * <li>{@code readWord2Vec(file, extendedModel)}</li>
 * <li>{@code readAsExtendedModel(file)} if {@code extendedModel} is true,
 *     otherwise {@code readAsSimplifiedModel(file)}</li>
 * <li>{@link #readAsCsv(File)}</li>
 * <li>{@link #readAsBinary(File)}</li>
 * <li>{@link #readAsBinaryNoLineBreaks(File)}</li>
 * </ol>
 * <p>
 * While loading, periodic GC is switched off and the occasional GC frequency is set to 50000; both
 * settings are put back to their previous values before this method returns or throws.
 *
 * @param file          model file
 * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
 * @return word2vec model
 * @throws ND4JIllegalStateException if {@code file} does not exist or is not a regular file
 * @throws RuntimeException          if none of the loaders can read the file; the cause is the
 *                                   failure of the last loader tried
 */
public static Word2Vec readWord2VecModel(File file, boolean extendedModel) {
    if (!file.exists() || !file.isFile()) {
        throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
    }

    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    if (originalPeriodic) {
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    }
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);

    try {
        try {
            return readWord2Vec(file, extendedModel);
        } catch (Exception readSequenceVectorsException) {
            // expected while probing formats: fall through to the next loader
            log.debug("Cannot read model with readWord2Vec, trying next format", readSequenceVectorsException);
        }

        try {
            return extendedModel
                    ? readAsExtendedModel(file)
                    : readAsSimplifiedModel(file);
        } catch (Exception loadFromFileException) {
            log.debug("Cannot read model as DL4J compressed model, trying next format", loadFromFileException);
        }

        try {
            return readAsCsv(file);
        } catch (Exception readCsvException) {
            log.debug("Cannot read model as CSV, trying next format", readCsvException);
        }

        try {
            return readAsBinary(file);
        } catch (Exception readBinaryException) {
            log.debug("Cannot read model as binary with line breaks, trying next format", readBinaryException);
        }

        try {
            return readAsBinaryNoLineBreaks(file);
        } catch (Exception readModelException) {
            log.error("Unable to guess input file format", readModelException);
            throw new RuntimeException(
                    "Unable to guess input file format. Please use corresponding loader directly",
                    readModelException);
        }
    } finally {
        // restore the GC settings changed above, on success and on failure alike
        if (originalPeriodic) {
            Nd4j.getMemoryManager().togglePeriodicGc(true);
        }
        Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
    }
}
/**
 * Loads a Word2Vec model from a binary file whose entries are not separated by line breaks.
 * <p>
 * The file is opened with {@code fileStream(File)} and handed to
 * {@link #readAsBinaryNoLineBreaks(InputStream)}; the stream is closed before this method returns.
 *
 * @param file binary model file
 * @return Word2Vec model
 * @throws RuntimeException wrapping any {@link IOException} raised while opening or closing the file
 */
public static Word2Vec readAsBinaryNoLineBreaks(@NonNull File file) {
    try (InputStream modelStream = fileStream(file)) {
        Word2Vec model = readAsBinaryNoLineBreaks(modelStream);
        return model;
    } catch (IOException ioException) {
        throw new RuntimeException(ioException);
    }
}
public static Word2Vec readAsBinaryNoLineBreaks(@NonNull InputStream inputStream) {
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive(); boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency(); int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
Word2Vec vec;
// try to load without linebreaks // try to load without linebreaks
try { try {
if (originalPeriodic) if (originalPeriodic) {
Nd4j.getMemoryManager().togglePeriodicGc(true); Nd4j.getMemoryManager().togglePeriodicGc(true);
}
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq); Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
vec = readBinaryModel(file, false, false); return readBinaryModel(inputStream, false, false);
return vec; } catch (Exception readModelException) {
} catch (Exception ez) { log.error("Cannot read binary model", readModelException);
throw new RuntimeException( throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
"Unable to guess input file format. Please use corresponding loader directly"); }
}
public static Word2Vec readAsBinary(@NonNull File file) {
try (InputStream inputStream = fileStream(file)) {
return readAsBinary(inputStream);
} catch (IOException readCsvException) {
throw new RuntimeException(readCsvException);
} }
} }
/** /**
* This method loads Word2Vec model from binary file * This method loads Word2Vec model from binary input stream.
* *
* @param file File * @param inputStream binary input stream
* @return Word2Vec * @return Word2Vec
*/ */
public static Word2Vec readAsBinary(@NonNull File file) { public static Word2Vec readAsBinary(@NonNull InputStream inputStream) {
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive(); boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency(); int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
Word2Vec vec;
// we fallback to trying binary model instead // we fallback to trying binary model instead
try { try {
log.debug("Trying binary model restoration..."); log.debug("Trying binary model restoration...");
if (originalPeriodic) if (originalPeriodic) {
Nd4j.getMemoryManager().togglePeriodicGc(true); Nd4j.getMemoryManager().togglePeriodicGc(true);
}
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq); Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
vec = readBinaryModel(file, true, false); return readBinaryModel(inputStream, true, false);
return vec; } catch (Exception readModelException) {
} catch (Exception ey) { throw new RuntimeException(readModelException);
throw new RuntimeException(ey); }
}
public static Word2Vec readAsCsv(@NonNull File file) {
try (InputStream inputStream = fileStream(file)) {
return readAsCsv(inputStream);
} catch (IOException readCsvException) {
throw new RuntimeException(readCsvException);
} }
} }
/** /**
* This method loads Word2Vec model from csv file * This method loads Word2Vec model from csv file
* *
* @param file File * @param inputStream input stream
* @return Word2Vec * @return Word2Vec model
*/ */
public static Word2Vec readAsCsv(@NonNull File file) { public static Word2Vec readAsCsv(@NonNull InputStream inputStream) {
Word2Vec vec;
VectorsConfiguration configuration = new VectorsConfiguration(); VectorsConfiguration configuration = new VectorsConfiguration();
// let's try to load this file as csv file // let's try to load this file as csv file
try { try {
log.debug("Trying CSV model restoration..."); log.debug("Trying CSV model restoration...");
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(file); Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(inputStream);
Word2Vec.Builder builder = new Word2Vec.Builder().lookupTable(pair.getFirst()).useAdaGrad(false) Word2Vec.Builder builder = new Word2Vec
.vocabCache(pair.getSecond()).layerSize(pair.getFirst().layerSize()) .Builder()
.lookupTable(pair.getFirst())
.useAdaGrad(false)
.vocabCache(pair.getSecond())
.layerSize(pair.getFirst().layerSize())
// we don't use hs here, because model is incomplete // we don't use hs here, because model is incomplete
.useHierarchicSoftmax(false).resetModel(false); .useHierarchicSoftmax(false)
.resetModel(false);
TokenizerFactory factory = getTokenizerFactory(configuration); TokenizerFactory factory = getTokenizerFactory(configuration);
if (factory != null) if (factory != null) {
builder.tokenizerFactory(factory); builder.tokenizerFactory(factory);
}
vec = builder.build(); return builder.build();
return vec;
} catch (Exception ex) { } catch (Exception ex) {
throw new RuntimeException("Unable to load model in CSV format"); throw new RuntimeException("Unable to load model in CSV format");
} }
} }
/**
* This method just loads full compressed model.
*/
private static Word2Vec readAsExtendedModel(@NonNull File file) throws IOException { private static Word2Vec readAsExtendedModel(@NonNull File file) throws IOException {
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency(); int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive(); boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
log.debug("Trying full model restoration..."); log.debug("Trying full model restoration...");
// this method just loads full compressed model
if (originalPeriodic) if (originalPeriodic) {
Nd4j.getMemoryManager().togglePeriodicGc(true); Nd4j.getMemoryManager().togglePeriodicGc(true);
}
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq); Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
@ -2627,67 +2760,6 @@ public class WordVectorSerializer {
return vec; return vec;
} }
/**
 * This method loads a Word2Vec model from a file of unknown format, trying the available
 * loaders one after another until one of them succeeds. Supported inputs:
 * 1) Binary model, either compressed or not. Like well-known Google Model
 * 2) Popular CSV word2vec text format
 * 3) DL4j compressed format
 * <p>
 * Please note: if extended data isn't available, only weights will be loaded instead.
 *
 * @param file          model file; must exist and be a regular file
 * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
 * @return the model produced by the first loader that succeeds
 * @throws ND4JIllegalStateException if the file doesn't exist or isn't a regular file
 * @throws RuntimeException          if none of the loaders could read the file; the failure of
 *                                   every attempted loader is attached as a suppressed exception
 */
public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedModel) {
    if (!file.exists() || !file.isFile())
        throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");

    // NOTE(review): periodic GC is switched off and the occasional GC frequency raised before
    // loading - presumably to avoid GC interference while a large model is read; confirm.
    if (Nd4j.getMemoryManager().isPeriodicGcActive())
        Nd4j.getMemoryManager().togglePeriodicGc(false);

    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);

    // thrown only when every loader below has failed; collects each individual failure
    RuntimeException failure = new RuntimeException(
            "Unable to guess input file format. Please use corresponding loader directly");

    // attempt 1: zip format
    try {
        return readWord2Vec(file, extendedModel);
    } catch (Exception e) {
        failure.addSuppressed(e);
    }

    // attempt 2: extended or simplified model, depending on the caller's request
    try {
        return extendedModel ? readAsExtendedModel(file) : readAsSimplifiedModel(file);
    } catch (Exception e) {
        failure.addSuppressed(e);
    }

    // attempt 3: csv format
    try {
        return readAsCsv(file);
    } catch (Exception e) {
        failure.addSuppressed(e);
    }

    // attempt 4: binary format
    try {
        return readAsBinary(file);
    } catch (Exception e) {
        failure.addSuppressed(e);
    }

    // attempt 5: binary format without line breaks
    try {
        return readAsBinaryNoLineBreaks(file);
    } catch (Exception e) {
        failure.addSuppressed(e);
    }

    throw failure;
}
protected static TokenizerFactory getTokenizerFactory(VectorsConfiguration configuration) { protected static TokenizerFactory getTokenizerFactory(VectorsConfiguration configuration) {
if (configuration == null) if (configuration == null)
return null; return null;
@ -3019,16 +3091,13 @@ public class WordVectorSerializer {
/** /**
* This method restores Word2Vec model from file * This method restores Word2Vec model from file
* *
* @param path String * @param path
* @param readExtendedTables booleab * @param readExtendedTables
* @return Word2Vec * @return Word2Vec
*/ */
public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables) public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables) {
throws IOException {
File file = new File(path); File file = new File(path);
Word2Vec word2Vec = readWord2Vec(file, readExtendedTables); return readWord2Vec(file, readExtendedTables);
return word2Vec;
} }
/** /**
@ -3139,11 +3208,12 @@ public class WordVectorSerializer {
* @param readExtendedTables boolean * @param readExtendedTables boolean
* @return Word2Vec * @return Word2Vec
*/ */
public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables) public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables) {
throws IOException { try (InputStream inputStream = fileStream(file)) {
return readWord2Vec(inputStream, readExtendedTables);
Word2Vec word2Vec = readWord2Vec(new FileInputStream(file), readExtendedTables); } catch (Exception readSequenceVectors) {
return word2Vec; throw new RuntimeException(readSequenceVectors);
}
} }
/** /**
@ -3153,13 +3223,19 @@ public class WordVectorSerializer {
* @param readExtendedTable boolean * @param readExtendedTable boolean
* @return Word2Vec * @return Word2Vec
*/ */
public static Word2Vec readWord2Vec(@NonNull InputStream stream, public static Word2Vec readWord2Vec(
boolean readExtendedTable) throws IOException { @NonNull InputStream stream,
boolean readExtendedTable) throws IOException {
SequenceVectors<VocabWord> vectors = readSequenceVectors(stream, readExtendedTable); SequenceVectors<VocabWord> vectors = readSequenceVectors(stream, readExtendedTable);
Word2Vec word2Vec = new Word2Vec.Builder(vectors.getConfiguration()).layerSize(vectors.getLayerSize()).build();
Word2Vec word2Vec = new Word2Vec
.Builder(vectors.getConfiguration())
.layerSize(vectors.getLayerSize())
.build();
word2Vec.setVocab(vectors.getVocab()); word2Vec.setVocab(vectors.getVocab());
word2Vec.setLookupTable(vectors.lookupTable()); word2Vec.setLookupTable(vectors.lookupTable());
word2Vec.setModelUtils(vectors.getModelUtils()); word2Vec.setModelUtils(vectors.getModelUtils());
return word2Vec; return word2Vec;
} }

View File

@ -37,8 +37,6 @@ import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals;
@Slf4j @Slf4j
public class TsneTest extends BaseDL4JTest { public class TsneTest extends BaseDL4JTest {

View File

@ -14,17 +14,14 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.models.sequencevectors.serialization; package org.deeplearning4j.models.embeddings.loader;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.commons.lang.StringUtils;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW; import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils; import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils; import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
import org.deeplearning4j.models.fasttext.FastText; import org.deeplearning4j.models.fasttext.FastText;
@ -47,7 +44,11 @@ import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import static org.junit.Assert.*; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@Slf4j @Slf4j
public class WordVectorSerializerTest extends BaseDL4JTest { public class WordVectorSerializerTest extends BaseDL4JTest {
@ -78,10 +79,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2), syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2); syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
InMemoryLookupTable<VocabWord> lookupTable = InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
(InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>() .Builder<VocabWord>()
.useAdaGrad(false).cache(cache) .useAdaGrad(false)
.build(); .cache(cache)
.build();
lookupTable.setSyn0(syn0); lookupTable.setSyn0(syn0);
lookupTable.setSyn1(syn1); lookupTable.setSyn1(syn1);
@ -92,7 +94,6 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
lookupTable(lookupTable). lookupTable(lookupTable).
build(); build();
SequenceVectors<VocabWord> deser = null; SequenceVectors<VocabWord> deser = null;
String json = StringUtils.EMPTY;
try { try {
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
WordVectorSerializer.writeSequenceVectors(vectors, baos); WordVectorSerializer.writeSequenceVectors(vectors, baos);
@ -126,10 +127,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2), syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2); syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
InMemoryLookupTable<VocabWord> lookupTable = InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
(InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>() .Builder<VocabWord>()
.useAdaGrad(false).cache(cache) .useAdaGrad(false)
.build(); .cache(cache)
.build();
lookupTable.setSyn0(syn0); lookupTable.setSyn0(syn0);
lookupTable.setSyn1(syn1); lookupTable.setSyn1(syn1);
@ -204,10 +206,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2), syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2); syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
InMemoryLookupTable<VocabWord> lookupTable = InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
(InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>() .Builder<VocabWord>()
.useAdaGrad(false).cache(cache) .useAdaGrad(false)
.build(); .cache(cache)
.build();
lookupTable.setSyn0(syn0); lookupTable.setSyn0(syn0);
lookupTable.setSyn1(syn1); lookupTable.setSyn1(syn1);
@ -252,10 +255,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2), syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2); syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
InMemoryLookupTable<VocabWord> lookupTable = InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
(InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>() .Builder<VocabWord>()
.useAdaGrad(false).cache(cache) .useAdaGrad(false)
.build(); .cache(cache)
.build();
lookupTable.setSyn0(syn0); lookupTable.setSyn0(syn0);
lookupTable.setSyn1(syn1); lookupTable.setSyn1(syn1);
@ -267,7 +271,6 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
WeightLookupTable<VocabWord> deser = null; WeightLookupTable<VocabWord> deser = null;
try { try {
WordVectorSerializer.writeLookupTable(lookupTable, file); WordVectorSerializer.writeLookupTable(lookupTable, file);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
deser = WordVectorSerializer.readLookupTable(file); deser = WordVectorSerializer.readLookupTable(file);
} catch (Exception e) { } catch (Exception e) {
log.error("",e); log.error("",e);
@ -305,7 +308,6 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
FastText deser = null; FastText deser = null;
try { try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
deser = WordVectorSerializer.readWordVectors(new File(dir, "some.data")); deser = WordVectorSerializer.readWordVectors(new File(dir, "some.data"));
} catch (Exception e) { } catch (Exception e) {
log.error("",e); log.error("",e);
@ -323,4 +325,32 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
assertEquals(fastText.getInputFile(), deser.getInputFile()); assertEquals(fastText.getInputFile(), deser.getInputFile());
assertEquals(fastText.getOutputFile(), deser.getOutputFile()); assertEquals(fastText.getOutputFile(), deser.getOutputFile());
} }
@Test
public void testIsHeader_withValidHeader() {
    // a line made of two integers must be recognised as a header, checked against an empty cache
    AbstractCache<VocabWord> emptyCache = new AbstractCache<>();

    assertTrue(WordVectorSerializer.isHeader("48 100", emptyCache));
}
@Test
public void testIsHeader_notHeader() {
    // a word followed by its vector components must not be mistaken for a header
    AbstractCache<VocabWord> emptyCache = new AbstractCache<>();
    String vectorLine = "your -0.0017603 0.0030831 0.00069072 0.0020581 -0.0050952 -2.2573e-05 -0.001141";

    assertFalse(WordVectorSerializer.isHeader(vectorLine, emptyCache));
}
} }

View File

@ -1,9 +1,9 @@
package org.deeplearning4j.models.fasttext; package org.deeplearning4j.models.fasttext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec; import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator; import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.junit.Rule; import org.junit.Rule;
@ -14,13 +14,14 @@ import org.nd4j.common.primitives.Pair;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import java.io.File; import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException; import java.io.IOException;
import static org.hamcrest.CoreMatchers.hasItems;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
@Slf4j @Slf4j
public class FastTextTest extends BaseDL4JTest { public class FastTextTest extends BaseDL4JTest {
@ -32,7 +33,6 @@ public class FastTextTest extends BaseDL4JTest {
private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin"); private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin");
private File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec"); private File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec");
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();
@ -90,7 +90,7 @@ public class FastTextTest extends BaseDL4JTest {
} }
@Test @Test
public void tesLoadCBOWModel() throws IOException { public void tesLoadCBOWModel() {
FastText fastText = new FastText(cbowModelFile); FastText fastText = new FastText(cbowModelFile);
fastText.test(cbowModelFile); fastText.test(cbowModelFile);
@ -155,7 +155,7 @@ public class FastTextTest extends BaseDL4JTest {
} }
@Test @Test
public void testVocabulary() throws IOException { public void testVocabulary() {
FastText fastText = new FastText(supModelFile); FastText fastText = new FastText(supModelFile);
assertEquals(48, fastText.vocab().numWords()); assertEquals(48, fastText.vocab().numWords());
assertEquals(48, fastText.vocabSize()); assertEquals(48, fastText.vocabSize());
@ -171,78 +171,73 @@ public class FastTextTest extends BaseDL4JTest {
} }
@Test @Test
public void testLoadIterator() { public void testLoadIterator() throws FileNotFoundException {
try { SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); FastText
FastText fastText = .builder()
FastText.builder().supervised(true).iterator(iter).build(); .supervised(true)
fastText.loadIterator(); .iterator(iter)
.build()
} catch (IOException e) { .loadIterator();
log.error("",e);
}
} }
@Test(expected=IllegalStateException.class) @Test(expected=IllegalStateException.class)
public void testState() { public void testState() {
FastText fastText = new FastText(); FastText fastText = new FastText();
String label = fastText.predict("something"); fastText.predict("something");
} }
@Test @Test
public void testPretrainedVectors() throws IOException { public void testPretrainedVectors() throws IOException {
File output = testDir.newFile(); File output = testDir.newFile();
FastText fastText = FastText fastText = FastText
FastText.builder().supervised(true). .builder()
inputFile(inputFile.getAbsolutePath()). .supervised(true)
pretrainedVectorsFile(supervisedVectors.getAbsolutePath()). .inputFile(inputFile.getAbsolutePath())
outputFile(output.getAbsolutePath()).build(); .pretrainedVectorsFile(supervisedVectors.getAbsolutePath())
.outputFile(output.getAbsolutePath())
.build();
log.info("\nTraining supervised model ...\n"); log.info("\nTraining supervised model ...\n");
fastText.fit(); fastText.fit();
} }
@Test @Test
public void testWordsStatistics() throws IOException { public void testWordsStatistics() throws IOException {
File output = testDir.newFile(); File output = testDir.newFile();
FastText fastText = FastText fastText = FastText
FastText.builder().supervised(true). .builder()
inputFile(inputFile.getAbsolutePath()). .supervised(true)
outputFile(output.getAbsolutePath()).build(); .inputFile(inputFile.getAbsolutePath())
.outputFile(output.getAbsolutePath())
.build();
log.info("\nTraining supervised model ...\n"); log.info("\nTraining supervised model ...\n");
fastText.fit(); fastText.fit();
Word2Vec word2Vec = WordVectorSerializer.readAsCsv(new File(output.getAbsolutePath() + ".vec")); File file = new File(output.getAbsolutePath() + ".vec");
Word2Vec word2Vec = WordVectorSerializer.readAsCsv(file);
assertEquals(48, word2Vec.getVocab().numWords()); assertEquals(48, word2Vec.getVocab().numWords());
assertEquals("", 0.1667751520872116, word2Vec.similarity("Football", "teams"), 1e-4);
System.out.println(word2Vec.wordsNearest("association", 3)); assertEquals("", 0.10083991289138794, word2Vec.similarity("professional", "minutes"), 1e-4);
System.out.println(word2Vec.similarity("Football", "teams")); assertEquals("", Double.NaN, word2Vec.similarity("java","cpp"), 0.0);
System.out.println(word2Vec.similarity("professional", "minutes")); assertThat(word2Vec.wordsNearest("association", 3), hasItems("Football", "Soccer", "men's"));
System.out.println(word2Vec.similarity("java","cpp"));
} }
@Test @Test
public void testWordsNativeStatistics() throws IOException { public void testWordsNativeStatistics() {
File output = testDir.newFile();
FastText fastText = new FastText(); FastText fastText = new FastText();
fastText.loadPretrainedVectors(supervisedVectors); fastText.loadPretrainedVectors(supervisedVectors);
log.info("\nTraining supervised model ...\n"); log.info("\nTraining supervised model ...\n");
assertEquals(48, fastText.vocab().numWords()); assertEquals(48, fastText.vocab().numWords());
assertThat(fastText.wordsNearest("association", 3), hasItems("most","eleven","hours"));
String[] result = new String[3];
fastText.wordsNearest("association", 3).toArray(result);
assertArrayEquals(new String[]{"most","eleven","hours"}, result);
assertEquals(0.1657, fastText.similarity("Football", "teams"), 1e-4); assertEquals(0.1657, fastText.similarity("Football", "teams"), 1e-4);
assertEquals(0.3661, fastText.similarity("professional", "minutes"), 1e-4); assertEquals(0.3661, fastText.similarity("professional", "minutes"), 1e-4);
assertEquals(Double.NaN, fastText.similarity("java","cpp"), 1e-4); assertEquals(Double.NaN, fastText.similarity("java","cpp"), 0.0);
} }
} }