Merge pull request #8908 from hosuaby/feature/loadModelFromStream
FEATURE: change API of WordVectorSerializer. Add posibility to read m…master
commit
58fe365c21
|
@ -856,15 +856,26 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFastText() {
|
public void testFastText() {
|
||||||
|
File[] files = { fastTextRaw, fastTextZip, fastTextGzip };
|
||||||
File[] files = {fastTextRaw, fastTextZip, fastTextGzip};
|
|
||||||
for (File file : files) {
|
for (File file : files) {
|
||||||
try {
|
try {
|
||||||
Word2Vec word2Vec = WordVectorSerializer.readAsCsv(file);
|
Word2Vec word2Vec = WordVectorSerializer.readAsCsv(file);
|
||||||
assertEquals(99, word2Vec.getVocab().numWords());
|
assertEquals(99, word2Vec.getVocab().numWords());
|
||||||
|
} catch (Exception readCsvException) {
|
||||||
|
fail("Failure for input file " + file.getAbsolutePath() + " " + readCsvException.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} catch (Exception e) {
|
@Test
|
||||||
fail("Failure for input file " + file.getAbsolutePath() + " " + e.getMessage());
|
public void testFastText_readWord2VecModel() {
|
||||||
|
File[] files = { fastTextRaw, fastTextZip, fastTextGzip };
|
||||||
|
for (File file : files) {
|
||||||
|
try {
|
||||||
|
Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel(file);
|
||||||
|
assertEquals(99, word2Vec.getVocab().numWords());
|
||||||
|
} catch (Exception readCsvException) {
|
||||||
|
fail("Failure for input file " + file.getAbsolutePath() + " " + readCsvException.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -84,6 +84,12 @@
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.awaitility</groupId>
|
||||||
|
<artifactId>awaitility</artifactId>
|
||||||
|
<version>4.0.2</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,14 +17,45 @@
|
||||||
|
|
||||||
package org.deeplearning4j.models.embeddings.loader;
|
package org.deeplearning4j.models.embeddings.loader;
|
||||||
|
|
||||||
import lombok.*;
|
import java.io.BufferedInputStream;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import java.io.BufferedOutputStream;
|
||||||
|
import java.io.BufferedReader;
|
||||||
|
import java.io.BufferedWriter;
|
||||||
|
import java.io.ByteArrayInputStream;
|
||||||
|
import java.io.DataInputStream;
|
||||||
|
import java.io.DataOutputStream;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.FileReader;
|
||||||
|
import java.io.FileWriter;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.InputStreamReader;
|
||||||
|
import java.io.ObjectInputStream;
|
||||||
|
import java.io.ObjectOutputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.io.OutputStreamWriter;
|
||||||
|
import java.io.PrintWriter;
|
||||||
|
import java.io.UnsupportedEncodingException;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
import java.util.zip.GZIPInputStream;
|
||||||
|
import java.util.zip.ZipEntry;
|
||||||
|
import java.util.zip.ZipFile;
|
||||||
|
import java.util.zip.ZipInputStream;
|
||||||
|
import java.util.zip.ZipOutputStream;
|
||||||
|
|
||||||
import org.apache.commons.codec.binary.Base64;
|
import org.apache.commons.codec.binary.Base64;
|
||||||
import org.apache.commons.compress.compressors.gzip.GzipUtils;
|
import org.apache.commons.compress.compressors.gzip.GzipUtils;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.apache.commons.io.LineIterator;
|
import org.apache.commons.io.LineIterator;
|
||||||
import org.apache.commons.io.output.CloseShieldOutputStream;
|
import org.apache.commons.io.output.CloseShieldOutputStream;
|
||||||
|
import org.deeplearning4j.common.util.DL4JFileUtils;
|
||||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||||
|
@ -50,26 +82,25 @@ import org.deeplearning4j.text.documentiterator.LabelsSource;
|
||||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
|
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||||
import org.deeplearning4j.common.util.DL4JFileUtils;
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.nd4j.common.util.OneTimeLogger;
|
||||||
import org.nd4j.compression.impl.NoOp;
|
import org.nd4j.compression.impl.NoOp;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
||||||
import org.nd4j.shade.jackson.databind.MapperFeature;
|
import org.nd4j.shade.jackson.databind.MapperFeature;
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||||
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
||||||
import org.nd4j.storage.CompressedRamStorage;
|
import org.nd4j.storage.CompressedRamStorage;
|
||||||
import org.nd4j.common.util.OneTimeLogger;
|
|
||||||
|
|
||||||
import java.io.*;
|
import lombok.AllArgsConstructor;
|
||||||
import java.nio.charset.StandardCharsets;
|
import lombok.Data;
|
||||||
import java.util.ArrayList;
|
import lombok.NoArgsConstructor;
|
||||||
import java.util.List;
|
import lombok.NonNull;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import java.util.zip.*;
|
import lombok.val;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This is utility class, providing various methods for WordVectors serialization
|
* This is utility class, providing various methods for WordVectors serialization
|
||||||
|
@ -85,14 +116,17 @@ import java.util.zip.*;
|
||||||
* {@link #writeWord2VecModel(Word2Vec, OutputStream)}
|
* {@link #writeWord2VecModel(Word2Vec, OutputStream)}
|
||||||
*
|
*
|
||||||
* <li>Deserializers for Word2Vec:</li>
|
* <li>Deserializers for Word2Vec:</li>
|
||||||
* {@link #readWord2VecModel(File)}
|
|
||||||
* {@link #readWord2VecModel(String)}
|
* {@link #readWord2VecModel(String)}
|
||||||
* {@link #readWord2VecModel(File, boolean)}
|
|
||||||
* {@link #readWord2VecModel(String, boolean)}
|
* {@link #readWord2VecModel(String, boolean)}
|
||||||
|
* {@link #readWord2VecModel(File)}
|
||||||
|
* {@link #readWord2VecModel(File, boolean)}
|
||||||
* {@link #readAsBinaryNoLineBreaks(File)}
|
* {@link #readAsBinaryNoLineBreaks(File)}
|
||||||
|
* {@link #readAsBinaryNoLineBreaks(InputStream)}
|
||||||
* {@link #readAsBinary(File)}
|
* {@link #readAsBinary(File)}
|
||||||
|
* {@link #readAsBinary(InputStream)}
|
||||||
* {@link #readAsCsv(File)}
|
* {@link #readAsCsv(File)}
|
||||||
* {@link #readBinaryModel(File, boolean, boolean)}
|
* {@link #readAsCsv(InputStream)}
|
||||||
|
* {@link #readBinaryModel(InputStream, boolean, boolean)}
|
||||||
* {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)}
|
* {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)}
|
||||||
* {@link #readWord2Vec(String, boolean)}
|
* {@link #readWord2Vec(String, boolean)}
|
||||||
* {@link #readWord2Vec(File, boolean)}
|
* {@link #readWord2Vec(File, boolean)}
|
||||||
|
@ -117,6 +151,7 @@ import java.util.zip.*;
|
||||||
* {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
|
* {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
|
||||||
* {@link #fromPair(Pair)}
|
* {@link #fromPair(Pair)}
|
||||||
* {@link #loadTxt(File)}
|
* {@link #loadTxt(File)}
|
||||||
|
* {@link #loadTxt(InputStream)}
|
||||||
*
|
*
|
||||||
* <li>Serializers to tSNE format</li>
|
* <li>Serializers to tSNE format</li>
|
||||||
* {@link #writeTsneFormat(Glove, INDArray, File)}
|
* {@link #writeTsneFormat(Glove, INDArray, File)}
|
||||||
|
@ -151,6 +186,7 @@ import java.util.zip.*;
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
* @author raver119
|
* @author raver119
|
||||||
* @author alexander@skymind.io
|
* @author alexander@skymind.io
|
||||||
|
* @author Alexei KLENIN
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class WordVectorSerializer {
|
public class WordVectorSerializer {
|
||||||
|
@ -215,18 +251,22 @@ public class WordVectorSerializer {
|
||||||
}*/
|
}*/
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Read a binary word2vec file.
|
* Read a binary word2vec from input stream.
|
||||||
*
|
*
|
||||||
* @param modelFile the File to read
|
* @param inputStream input stream to read
|
||||||
* @param linebreaks if true, the reader expects each word/vector to be in a separate line, terminated
|
* @param linebreaks if true, the reader expects each word/vector to be in a separate line, terminated
|
||||||
* by a line break
|
* by a line break
|
||||||
|
* @param normalize
|
||||||
|
*
|
||||||
* @return a {@link Word2Vec model}
|
* @return a {@link Word2Vec model}
|
||||||
* @throws NumberFormatException
|
* @throws NumberFormatException
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
* @throws FileNotFoundException
|
* @throws FileNotFoundException
|
||||||
*/
|
*/
|
||||||
public static Word2Vec readBinaryModel(File modelFile, boolean linebreaks, boolean normalize)
|
public static Word2Vec readBinaryModel(
|
||||||
throws NumberFormatException, IOException {
|
InputStream inputStream,
|
||||||
|
boolean linebreaks,
|
||||||
|
boolean normalize) throws NumberFormatException, IOException {
|
||||||
InMemoryLookupTable<VocabWord> lookupTable;
|
InMemoryLookupTable<VocabWord> lookupTable;
|
||||||
VocabCache<VocabWord> cache;
|
VocabCache<VocabWord> cache;
|
||||||
INDArray syn0;
|
INDArray syn0;
|
||||||
|
@ -240,9 +280,7 @@ public class WordVectorSerializer {
|
||||||
|
|
||||||
Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
|
Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
|
||||||
|
|
||||||
try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName())
|
try (DataInputStream dis = new DataInputStream(inputStream)) {
|
||||||
? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile));
|
|
||||||
DataInputStream dis = new DataInputStream(bis)) {
|
|
||||||
words = Integer.parseInt(ReadHelper.readString(dis));
|
words = Integer.parseInt(ReadHelper.readString(dis));
|
||||||
size = Integer.parseInt(ReadHelper.readString(dis));
|
size = Integer.parseInt(ReadHelper.readString(dis));
|
||||||
syn0 = Nd4j.create(words, size);
|
syn0 = Nd4j.create(words, size);
|
||||||
|
@ -250,23 +288,26 @@ public class WordVectorSerializer {
|
||||||
|
|
||||||
printOutProjectedMemoryUse(words, size, 1);
|
printOutProjectedMemoryUse(words, size, 1);
|
||||||
|
|
||||||
lookupTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>().cache(cache)
|
lookupTable = new InMemoryLookupTable.Builder<VocabWord>()
|
||||||
.useHierarchicSoftmax(false).vectorLength(size).build();
|
.cache(cache)
|
||||||
|
.useHierarchicSoftmax(false)
|
||||||
|
.vectorLength(size)
|
||||||
|
.build();
|
||||||
|
|
||||||
int cnt = 0;
|
|
||||||
String word;
|
String word;
|
||||||
float[] vector = new float[size];
|
float[] vector = new float[size];
|
||||||
for (int i = 0; i < words; i++) {
|
for (int i = 0; i < words; i++) {
|
||||||
|
|
||||||
word = ReadHelper.readString(dis);
|
word = ReadHelper.readString(dis);
|
||||||
log.trace("Loading " + word + " with word " + i);
|
log.trace("Loading {} with word {}", word, i);
|
||||||
|
|
||||||
for (int j = 0; j < size; j++) {
|
for (int j = 0; j < size; j++) {
|
||||||
vector[j] = ReadHelper.readFloat(dis);
|
vector[j] = ReadHelper.readFloat(dis);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cache.containsWord(word))
|
if (cache.containsWord(word)) {
|
||||||
throw new ND4JIllegalStateException("Tried to add existing word. Probably time to switch linebreaks mode?");
|
throw new ND4JIllegalStateException(
|
||||||
|
"Tried to add existing word. Probably time to switch linebreaks mode?");
|
||||||
|
}
|
||||||
|
|
||||||
syn0.putRow(i, normalize ? Transforms.unitVec(Nd4j.create(vector)) : Nd4j.create(vector));
|
syn0.putRow(i, normalize ? Transforms.unitVec(Nd4j.create(vector)) : Nd4j.create(vector));
|
||||||
|
|
||||||
|
@ -285,25 +326,31 @@ public class WordVectorSerializer {
|
||||||
Nd4j.getMemoryManager().invokeGcOccasionally();
|
Nd4j.getMemoryManager().invokeGcOccasionally();
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
if (originalPeriodic)
|
if (originalPeriodic) {
|
||||||
Nd4j.getMemoryManager().togglePeriodicGc(true);
|
Nd4j.getMemoryManager().togglePeriodicGc(true);
|
||||||
|
}
|
||||||
|
|
||||||
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
|
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
lookupTable.setSyn0(syn0);
|
lookupTable.setSyn0(syn0);
|
||||||
|
|
||||||
|
Word2Vec ret = new Word2Vec
|
||||||
Word2Vec ret = new Word2Vec.Builder().useHierarchicSoftmax(false).resetModel(false).layerSize(syn0.columns())
|
.Builder()
|
||||||
.allowParallelTokenization(true).elementsLearningAlgorithm(new SkipGram<VocabWord>())
|
.useHierarchicSoftmax(false)
|
||||||
.learningRate(0.025).windowSize(5).workers(1).build();
|
.resetModel(false)
|
||||||
|
.layerSize(syn0.columns())
|
||||||
|
.allowParallelTokenization(true)
|
||||||
|
.elementsLearningAlgorithm(new SkipGram<VocabWord>())
|
||||||
|
.learningRate(0.025)
|
||||||
|
.windowSize(5)
|
||||||
|
.workers(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
ret.setVocab(cache);
|
ret.setVocab(cache);
|
||||||
ret.setLookupTable(lookupTable);
|
ret.setLookupTable(lookupTable);
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -927,7 +974,7 @@ public class WordVectorSerializer {
|
||||||
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes,
|
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes,
|
||||||
@NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
|
@NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
|
||||||
// first we load syn0
|
// first we load syn0
|
||||||
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(vectors);
|
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors));
|
||||||
InMemoryLookupTable lookupTable = pair.getFirst();
|
InMemoryLookupTable lookupTable = pair.getFirst();
|
||||||
lookupTable.setNegative(configuration.getNegative());
|
lookupTable.setNegative(configuration.getNegative());
|
||||||
if (configuration.getNegative() > 0)
|
if (configuration.getNegative() > 0)
|
||||||
|
@ -1604,133 +1651,105 @@ public class WordVectorSerializer {
|
||||||
* @param vectorsFile the path of the file to load\
|
* @param vectorsFile the path of the file to load\
|
||||||
* @return
|
* @return
|
||||||
* @throws FileNotFoundException if the file does not exist
|
* @throws FileNotFoundException if the file does not exist
|
||||||
* @deprecated Use {@link #loadTxt(File)}
|
* @deprecated Use {@link #loadTxt(InputStream)}
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public static WordVectors loadTxtVectors(File vectorsFile)
|
public static WordVectors loadTxtVectors(File vectorsFile) throws IOException {
|
||||||
throws IOException {
|
FileInputStream fileInputStream = new FileInputStream(vectorsFile);
|
||||||
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(vectorsFile);
|
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(fileInputStream);
|
||||||
return fromPair(pair);
|
return fromPair(pair);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static InputStream fileStream(@NonNull File file) throws IOException {
|
||||||
|
boolean isZip = file.getName().endsWith(".zip");
|
||||||
|
boolean isGzip = GzipUtils.isCompressedFilename(file.getName());
|
||||||
|
|
||||||
|
InputStream inputStream;
|
||||||
|
|
||||||
|
if (isZip) {
|
||||||
|
inputStream = decompressZip(file);
|
||||||
|
} else if (isGzip) {
|
||||||
|
FileInputStream fis = new FileInputStream(file);
|
||||||
|
inputStream = new GZIPInputStream(fis);
|
||||||
|
} else {
|
||||||
|
inputStream = new FileInputStream(file);
|
||||||
|
}
|
||||||
|
|
||||||
|
return new BufferedInputStream(inputStream);
|
||||||
|
}
|
||||||
|
|
||||||
private static InputStream decompressZip(File modelFile) throws IOException {
|
private static InputStream decompressZip(File modelFile) throws IOException {
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
|
||||||
ZipFile zipFile = new ZipFile(modelFile);
|
ZipFile zipFile = new ZipFile(modelFile);
|
||||||
InputStream inputStream = null;
|
InputStream inputStream = null;
|
||||||
|
|
||||||
try (ZipInputStream zipStream = new ZipInputStream(new BufferedInputStream(new FileInputStream(modelFile)))) {
|
try (FileInputStream fis = new FileInputStream(modelFile);
|
||||||
|
BufferedInputStream bis = new BufferedInputStream(fis);
|
||||||
ZipEntry entry = null;
|
ZipInputStream zipStream = new ZipInputStream(bis)) {
|
||||||
|
ZipEntry entry;
|
||||||
if ((entry = zipStream.getNextEntry()) != null) {
|
if ((entry = zipStream.getNextEntry()) != null) {
|
||||||
|
|
||||||
inputStream = zipFile.getInputStream(entry);
|
inputStream = zipFile.getInputStream(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (zipStream.getNextEntry() != null) {
|
if (zipStream.getNextEntry() != null) {
|
||||||
throw new RuntimeException("Zip archive " + modelFile + " contains more than 1 file");
|
throw new RuntimeException("Zip archive " + modelFile + " contains more than 1 file");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return inputStream;
|
return inputStream;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static BufferedReader createReader(File vectorsFile) throws IOException {
|
public static Pair<InMemoryLookupTable, VocabCache> loadTxt(@NonNull File file) {
|
||||||
InputStreamReader inputStreamReader;
|
try (InputStream inputStream = fileStream(file)) {
|
||||||
try {
|
return loadTxt(inputStream);
|
||||||
inputStreamReader = new InputStreamReader(decompressZip(vectorsFile));
|
} catch (IOException readTestException) {
|
||||||
} catch (IOException e) {
|
throw new RuntimeException(readTestException);
|
||||||
inputStreamReader = new InputStreamReader(GzipUtils.isCompressedFilename(vectorsFile.getName())
|
|
||||||
? new GZIPInputStream(new FileInputStream(vectorsFile))
|
|
||||||
: new FileInputStream(vectorsFile), "UTF-8");
|
|
||||||
}
|
}
|
||||||
BufferedReader reader = new BufferedReader(inputStreamReader);
|
|
||||||
return reader;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads an in memory cache from the given path (sets syn0 and the vocab)
|
* Loads an in memory cache from the given input stream (sets syn0 and the vocab).
|
||||||
*
|
*
|
||||||
* @param vectorsFile the path of the file to load
|
* @param inputStream input stream
|
||||||
* @return a Pair holding the lookup table and the vocab cache.
|
* @return a {@link Pair} holding the lookup table and the vocab cache.
|
||||||
* @throws FileNotFoundException if the input file does not exist
|
|
||||||
*/
|
*/
|
||||||
public static Pair<InMemoryLookupTable, VocabCache> loadTxt(File vectorsFile)
|
public static Pair<InMemoryLookupTable, VocabCache> loadTxt(@NonNull InputStream inputStream) {
|
||||||
throws IOException, UnsupportedEncodingException {
|
AbstractCache<VocabWord> cache = new AbstractCache<>();
|
||||||
|
LineIterator lines = null;
|
||||||
|
|
||||||
|
try (InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
|
||||||
|
BufferedReader reader = new BufferedReader(inputStreamReader)) {
|
||||||
|
lines = IOUtils.lineIterator(reader);
|
||||||
|
|
||||||
AbstractCache cache = new AbstractCache<>();
|
|
||||||
BufferedReader reader = createReader(vectorsFile);
|
|
||||||
LineIterator iter = IOUtils.lineIterator(reader);
|
|
||||||
String line = null;
|
String line = null;
|
||||||
boolean hasHeader = false;
|
boolean hasHeader = false;
|
||||||
if (iter.hasNext()) {
|
|
||||||
line = iter.nextLine(); // skip header line
|
|
||||||
//look for spaces
|
|
||||||
if (!line.contains(" ")) {
|
|
||||||
log.debug("Skipping first line");
|
|
||||||
hasHeader = true;
|
|
||||||
} else {
|
|
||||||
// we should check for something that looks like proper word vectors here. i.e: 1 word at the 0 position, and bunch of floats further
|
|
||||||
String[] split = line.split(" ");
|
|
||||||
try {
|
|
||||||
long[] header = new long[split.length];
|
|
||||||
for (int x = 0; x < split.length; x++) {
|
|
||||||
header[x] = Long.parseLong(split[x]);
|
|
||||||
}
|
|
||||||
if (split.length < 4)
|
|
||||||
hasHeader = true;
|
|
||||||
// now we know, if that's all ints - it's just a header
|
|
||||||
// [0] - number of words
|
|
||||||
// [1] - vectorSize
|
|
||||||
// [2] - number of documents <-- DL4j-only value
|
|
||||||
if (split.length == 3)
|
|
||||||
cache.incrementTotalDocCount(header[2]);
|
|
||||||
|
|
||||||
printOutProjectedMemoryUse(header[0], (int) header[1], 1);
|
/* Check if first line is a header */
|
||||||
|
if (lines.hasNext()) {
|
||||||
hasHeader = true;
|
line = lines.nextLine();
|
||||||
|
hasHeader = isHeader(line, cache);
|
||||||
try {
|
|
||||||
reader.close();
|
|
||||||
} catch (Exception ex) {
|
|
||||||
}
|
|
||||||
} catch (Exception e) {
|
|
||||||
// if any conversion exception hits - that'll be considered header
|
|
||||||
hasHeader = false;
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
//reposition buffer to be one line ahead
|
|
||||||
if (hasHeader) {
|
if (hasHeader) {
|
||||||
line = "";
|
log.debug("First line is a header");
|
||||||
iter.close();
|
line = lines.nextLine();
|
||||||
//reader = new BufferedReader(new FileReader(vectorsFile));
|
|
||||||
reader = createReader(vectorsFile);
|
|
||||||
iter = IOUtils.lineIterator(reader);
|
|
||||||
iter.nextLine();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
List<INDArray> arrays = new ArrayList<>();
|
List<INDArray> arrays = new ArrayList<>();
|
||||||
long[] vShape = new long[]{1, -1};
|
long[] vShape = new long[]{ 1, -1 };
|
||||||
while (iter.hasNext()) {
|
|
||||||
if (line.isEmpty())
|
|
||||||
line = iter.nextLine();
|
|
||||||
String[] split = line.split(" ");
|
|
||||||
String word = ReadHelper.decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
|
|
||||||
VocabWord word1 = new VocabWord(1.0, word);
|
|
||||||
|
|
||||||
word1.setIndex(cache.numWords());
|
do {
|
||||||
|
String[] tokens = line.split(" ");
|
||||||
cache.addToken(word1);
|
String word = ReadHelper.decodeB64(tokens[0]);
|
||||||
|
VocabWord vocabWord = new VocabWord(1.0, word);
|
||||||
cache.addWordToIndex(word1.getIndex(), word);
|
vocabWord.setIndex(cache.numWords());
|
||||||
|
|
||||||
|
cache.addToken(vocabWord);
|
||||||
|
cache.addWordToIndex(vocabWord.getIndex(), word);
|
||||||
cache.putVocabWord(word);
|
cache.putVocabWord(word);
|
||||||
|
|
||||||
float[] vector = new float[split.length - 1];
|
float[] vector = new float[tokens.length - 1];
|
||||||
|
for (int i = 1; i < tokens.length; i++) {
|
||||||
for (int i = 1; i < split.length; i++) {
|
vector[i - 1] = Float.parseFloat(tokens[i]);
|
||||||
vector[i - 1] = Float.parseFloat(split[i]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
vShape[1] = vector.length;
|
vShape[1] = vector.length;
|
||||||
|
@ -1738,26 +1757,66 @@ public class WordVectorSerializer {
|
||||||
|
|
||||||
arrays.add(row);
|
arrays.add(row);
|
||||||
|
|
||||||
// workaround for skipped first row
|
line = lines.hasNext() ? lines.next() : null;
|
||||||
line = "";
|
} while (line != null);
|
||||||
}
|
|
||||||
|
|
||||||
INDArray syn = Nd4j.vstack(arrays);
|
INDArray syn = Nd4j.vstack(arrays);
|
||||||
|
|
||||||
InMemoryLookupTable lookupTable =
|
InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
|
||||||
(InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(arrays.get(0).columns())
|
.Builder<VocabWord>()
|
||||||
.useAdaGrad(false).cache(cache).useHierarchicSoftmax(false).build();
|
.vectorLength(arrays.get(0).columns())
|
||||||
|
.useAdaGrad(false)
|
||||||
|
.cache(cache)
|
||||||
|
.useHierarchicSoftmax(false)
|
||||||
|
.build();
|
||||||
|
|
||||||
lookupTable.setSyn0(syn);
|
lookupTable.setSyn0(syn);
|
||||||
|
|
||||||
iter.close();
|
return new Pair<>((InMemoryLookupTable) lookupTable, (VocabCache) cache);
|
||||||
|
} catch (IOException readeTextStreamException) {
|
||||||
try {
|
throw new RuntimeException(readeTextStreamException);
|
||||||
reader.close();
|
} finally {
|
||||||
} catch (Exception e) {
|
if (lines != null) {
|
||||||
|
lines.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return new Pair<>(lookupTable, (VocabCache) cache);
|
static boolean isHeader(String line, AbstractCache cache) {
|
||||||
|
if (!line.contains(" ")) {
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
|
||||||
|
/* We should check for something that looks like proper word vectors here. i.e: 1 word at the 0
|
||||||
|
* position, and bunch of floats further */
|
||||||
|
String[] headers = line.split(" ");
|
||||||
|
|
||||||
|
try {
|
||||||
|
long[] header = new long[headers.length];
|
||||||
|
for (int x = 0; x < headers.length; x++) {
|
||||||
|
header[x] = Long.parseLong(headers[x]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Now we know, if that's all ints - it's just a header
|
||||||
|
* [0] - number of words
|
||||||
|
* [1] - vectorLength
|
||||||
|
* [2] - number of documents <-- DL4j-only value
|
||||||
|
*/
|
||||||
|
if (headers.length == 3) {
|
||||||
|
long numberOfDocuments = header[2];
|
||||||
|
cache.incrementTotalDocCount(numberOfDocuments);
|
||||||
|
}
|
||||||
|
|
||||||
|
long numWords = header[0];
|
||||||
|
int vectorLength = (int) header[1];
|
||||||
|
printOutProjectedMemoryUse(numWords, vectorLength, 1);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
} catch (Exception notHeaderException) {
|
||||||
|
// if any conversion exception hits - that'll be considered header
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -2352,22 +2411,6 @@ public class WordVectorSerializer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* This method
|
|
||||||
* 1) Binary model, either compressed or not. Like well-known Google Model
|
|
||||||
* 2) Popular CSV word2vec text format
|
|
||||||
* 3) DL4j compressed format
|
|
||||||
* <p>
|
|
||||||
* Please note: Only weights will be loaded by this method.
|
|
||||||
*
|
|
||||||
* @param file
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static Word2Vec readWord2VecModel(@NonNull File file) {
|
|
||||||
return readWord2VecModel(file, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method
|
* This method
|
||||||
* 1) Binary model, either compressed or not. Like well-known Google Model
|
* 1) Binary model, either compressed or not. Like well-known Google Model
|
||||||
|
@ -2389,9 +2432,9 @@ public class WordVectorSerializer {
|
||||||
* 2) Popular CSV word2vec text format
|
* 2) Popular CSV word2vec text format
|
||||||
* 3) DL4j compressed format
|
* 3) DL4j compressed format
|
||||||
* <p>
|
* <p>
|
||||||
* Please note: if extended data isn't available, only weights will be loaded instead.
|
* Please note: Only weights will be loaded by this method.
|
||||||
*
|
*
|
||||||
* @param path
|
* @param path path to model file
|
||||||
* @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
|
* @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@ -2399,96 +2442,186 @@ public class WordVectorSerializer {
|
||||||
return readWord2VecModel(new File(path), extendedModel);
|
return readWord2VecModel(new File(path), extendedModel);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Word2Vec readAsBinaryNoLineBreaks(@NonNull File file) {
|
/**
|
||||||
|
* This method
|
||||||
|
* 1) Binary model, either compressed or not. Like well-known Google Model
|
||||||
|
* 2) Popular CSV word2vec text format
|
||||||
|
* 3) DL4j compressed format
|
||||||
|
* <p>
|
||||||
|
* Please note: Only weights will be loaded by this method.
|
||||||
|
*
|
||||||
|
* @param file
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static Word2Vec readWord2VecModel(File file) {
|
||||||
|
return readWord2VecModel(file, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method
|
||||||
|
* 1) Binary model, either compressed or not. Like well-known Google Model
|
||||||
|
* 2) Popular CSV word2vec text format
|
||||||
|
* 3) DL4j compressed format
|
||||||
|
* <p>
|
||||||
|
* Please note: if extended data isn't available, only weights will be loaded instead.
|
||||||
|
*
|
||||||
|
* @param file model file
|
||||||
|
* @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
|
||||||
|
* @return word2vec model
|
||||||
|
*/
|
||||||
|
public static Word2Vec readWord2VecModel(File file, boolean extendedModel) {
|
||||||
|
if (!file.exists() || !file.isFile()) {
|
||||||
|
throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
|
||||||
|
if (originalPeriodic) {
|
||||||
|
Nd4j.getMemoryManager().togglePeriodicGc(false);
|
||||||
|
}
|
||||||
|
Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
|
||||||
|
|
||||||
|
try {
|
||||||
|
return readWord2Vec(file, extendedModel);
|
||||||
|
} catch (Exception readSequenceVectors) {
|
||||||
|
try {
|
||||||
|
return extendedModel
|
||||||
|
? readAsExtendedModel(file)
|
||||||
|
: readAsSimplifiedModel(file);
|
||||||
|
} catch (Exception loadFromFileException) {
|
||||||
|
try {
|
||||||
|
return readAsCsv(file);
|
||||||
|
} catch (Exception readCsvException) {
|
||||||
|
try {
|
||||||
|
return readAsBinary(file);
|
||||||
|
} catch (Exception readBinaryException) {
|
||||||
|
try {
|
||||||
|
return readAsBinaryNoLineBreaks(file);
|
||||||
|
} catch (Exception readModelException) {
|
||||||
|
log.error("Unable to guess input file format", readModelException);
|
||||||
|
throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Word2Vec readAsBinaryNoLineBreaks(@NonNull File file) {
|
||||||
|
try (InputStream inputStream = fileStream(file)) {
|
||||||
|
return readAsBinaryNoLineBreaks(inputStream);
|
||||||
|
} catch (IOException readCsvException) {
|
||||||
|
throw new RuntimeException(readCsvException);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Word2Vec readAsBinaryNoLineBreaks(@NonNull InputStream inputStream) {
|
||||||
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
|
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
|
||||||
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
|
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
|
||||||
Word2Vec vec;
|
|
||||||
|
|
||||||
// try to load without linebreaks
|
// try to load without linebreaks
|
||||||
try {
|
try {
|
||||||
if (originalPeriodic)
|
if (originalPeriodic) {
|
||||||
Nd4j.getMemoryManager().togglePeriodicGc(true);
|
Nd4j.getMemoryManager().togglePeriodicGc(true);
|
||||||
|
}
|
||||||
|
|
||||||
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
|
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
|
||||||
|
|
||||||
vec = readBinaryModel(file, false, false);
|
return readBinaryModel(inputStream, false, false);
|
||||||
return vec;
|
} catch (Exception readModelException) {
|
||||||
} catch (Exception ez) {
|
log.error("Cannot read binary model", readModelException);
|
||||||
throw new RuntimeException(
|
throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
|
||||||
"Unable to guess input file format. Please use corresponding loader directly");
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Word2Vec readAsBinary(@NonNull File file) {
|
||||||
|
try (InputStream inputStream = fileStream(file)) {
|
||||||
|
return readAsBinary(inputStream);
|
||||||
|
} catch (IOException readCsvException) {
|
||||||
|
throw new RuntimeException(readCsvException);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method loads Word2Vec model from binary file
|
* This method loads Word2Vec model from binary input stream.
|
||||||
*
|
*
|
||||||
* @param file File
|
* @param inputStream binary input stream
|
||||||
* @return Word2Vec
|
* @return Word2Vec
|
||||||
*/
|
*/
|
||||||
public static Word2Vec readAsBinary(@NonNull File file) {
|
public static Word2Vec readAsBinary(@NonNull InputStream inputStream) {
|
||||||
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
|
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
|
||||||
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
|
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
|
||||||
|
|
||||||
Word2Vec vec;
|
|
||||||
|
|
||||||
// we fallback to trying binary model instead
|
// we fallback to trying binary model instead
|
||||||
try {
|
try {
|
||||||
log.debug("Trying binary model restoration...");
|
log.debug("Trying binary model restoration...");
|
||||||
|
|
||||||
if (originalPeriodic)
|
if (originalPeriodic) {
|
||||||
Nd4j.getMemoryManager().togglePeriodicGc(true);
|
Nd4j.getMemoryManager().togglePeriodicGc(true);
|
||||||
|
}
|
||||||
|
|
||||||
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
|
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
|
||||||
|
|
||||||
vec = readBinaryModel(file, true, false);
|
return readBinaryModel(inputStream, true, false);
|
||||||
return vec;
|
} catch (Exception readModelException) {
|
||||||
} catch (Exception ey) {
|
throw new RuntimeException(readModelException);
|
||||||
throw new RuntimeException(ey);
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Word2Vec readAsCsv(@NonNull File file) {
|
||||||
|
try (InputStream inputStream = fileStream(file)) {
|
||||||
|
return readAsCsv(inputStream);
|
||||||
|
} catch (IOException readCsvException) {
|
||||||
|
throw new RuntimeException(readCsvException);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method loads Word2Vec model from csv file
|
* This method loads Word2Vec model from csv file
|
||||||
*
|
*
|
||||||
* @param file File
|
* @param inputStream input stream
|
||||||
* @return Word2Vec
|
* @return Word2Vec model
|
||||||
*/
|
*/
|
||||||
public static Word2Vec readAsCsv(@NonNull File file) {
|
public static Word2Vec readAsCsv(@NonNull InputStream inputStream) {
|
||||||
|
|
||||||
Word2Vec vec;
|
|
||||||
VectorsConfiguration configuration = new VectorsConfiguration();
|
VectorsConfiguration configuration = new VectorsConfiguration();
|
||||||
|
|
||||||
// let's try to load this file as csv file
|
// let's try to load this file as csv file
|
||||||
try {
|
try {
|
||||||
log.debug("Trying CSV model restoration...");
|
log.debug("Trying CSV model restoration...");
|
||||||
|
|
||||||
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(file);
|
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(inputStream);
|
||||||
Word2Vec.Builder builder = new Word2Vec.Builder().lookupTable(pair.getFirst()).useAdaGrad(false)
|
Word2Vec.Builder builder = new Word2Vec
|
||||||
.vocabCache(pair.getSecond()).layerSize(pair.getFirst().layerSize())
|
.Builder()
|
||||||
|
.lookupTable(pair.getFirst())
|
||||||
|
.useAdaGrad(false)
|
||||||
|
.vocabCache(pair.getSecond())
|
||||||
|
.layerSize(pair.getFirst().layerSize())
|
||||||
// we don't use hs here, because model is incomplete
|
// we don't use hs here, because model is incomplete
|
||||||
.useHierarchicSoftmax(false).resetModel(false);
|
.useHierarchicSoftmax(false)
|
||||||
|
.resetModel(false);
|
||||||
|
|
||||||
TokenizerFactory factory = getTokenizerFactory(configuration);
|
TokenizerFactory factory = getTokenizerFactory(configuration);
|
||||||
if (factory != null)
|
if (factory != null) {
|
||||||
builder.tokenizerFactory(factory);
|
builder.tokenizerFactory(factory);
|
||||||
|
}
|
||||||
|
|
||||||
vec = builder.build();
|
return builder.build();
|
||||||
return vec;
|
|
||||||
} catch (Exception ex) {
|
} catch (Exception ex) {
|
||||||
throw new RuntimeException("Unable to load model in CSV format");
|
throw new RuntimeException("Unable to load model in CSV format");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method just loads full compressed model.
|
||||||
|
*/
|
||||||
private static Word2Vec readAsExtendedModel(@NonNull File file) throws IOException {
|
private static Word2Vec readAsExtendedModel(@NonNull File file) throws IOException {
|
||||||
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
|
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
|
||||||
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
|
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
|
||||||
|
|
||||||
log.debug("Trying full model restoration...");
|
log.debug("Trying full model restoration...");
|
||||||
// this method just loads full compressed model
|
|
||||||
|
|
||||||
if (originalPeriodic)
|
if (originalPeriodic) {
|
||||||
Nd4j.getMemoryManager().togglePeriodicGc(true);
|
Nd4j.getMemoryManager().togglePeriodicGc(true);
|
||||||
|
}
|
||||||
|
|
||||||
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
|
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
|
||||||
|
|
||||||
|
@ -2627,67 +2760,6 @@ public class WordVectorSerializer {
|
||||||
return vec;
|
return vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* This method
|
|
||||||
* 1) Binary model, either compressed or not. Like well-known Google Model
|
|
||||||
* 2) Popular CSV word2vec text format
|
|
||||||
* 3) DL4j compressed format
|
|
||||||
* <p>
|
|
||||||
* Please note: if extended data isn't available, only weights will be loaded instead.
|
|
||||||
*
|
|
||||||
* @param file
|
|
||||||
* @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedModel) {
|
|
||||||
|
|
||||||
if (!file.exists() || !file.isFile())
|
|
||||||
throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
|
|
||||||
|
|
||||||
Word2Vec vec = null;
|
|
||||||
|
|
||||||
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
|
|
||||||
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
|
|
||||||
if (originalPeriodic)
|
|
||||||
Nd4j.getMemoryManager().togglePeriodicGc(false);
|
|
||||||
Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
|
|
||||||
|
|
||||||
// try to load zip format
|
|
||||||
try {
|
|
||||||
vec = readWord2Vec(file, extendedModel);
|
|
||||||
return vec;
|
|
||||||
} catch (Exception e) {
|
|
||||||
// let's try to load this file as csv file
|
|
||||||
try {
|
|
||||||
if (extendedModel) {
|
|
||||||
vec = readAsExtendedModel(file);
|
|
||||||
return vec;
|
|
||||||
} else {
|
|
||||||
vec = readAsSimplifiedModel(file);
|
|
||||||
return vec;
|
|
||||||
}
|
|
||||||
} catch (Exception ex) {
|
|
||||||
try {
|
|
||||||
vec = readAsCsv(file);
|
|
||||||
return vec;
|
|
||||||
} catch (Exception exc) {
|
|
||||||
try {
|
|
||||||
vec = readAsBinary(file);
|
|
||||||
return vec;
|
|
||||||
} catch (Exception exce) {
|
|
||||||
try {
|
|
||||||
vec = readAsBinaryNoLineBreaks(file);
|
|
||||||
return vec;
|
|
||||||
|
|
||||||
} catch (Exception excep) {
|
|
||||||
throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected static TokenizerFactory getTokenizerFactory(VectorsConfiguration configuration) {
|
protected static TokenizerFactory getTokenizerFactory(VectorsConfiguration configuration) {
|
||||||
if (configuration == null)
|
if (configuration == null)
|
||||||
return null;
|
return null;
|
||||||
|
@ -3019,16 +3091,13 @@ public class WordVectorSerializer {
|
||||||
/**
|
/**
|
||||||
* This method restores Word2Vec model from file
|
* This method restores Word2Vec model from file
|
||||||
*
|
*
|
||||||
* @param path String
|
* @param path
|
||||||
* @param readExtendedTables booleab
|
* @param readExtendedTables
|
||||||
* @return Word2Vec
|
* @return Word2Vec
|
||||||
*/
|
*/
|
||||||
public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables)
|
public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables) {
|
||||||
throws IOException {
|
|
||||||
|
|
||||||
File file = new File(path);
|
File file = new File(path);
|
||||||
Word2Vec word2Vec = readWord2Vec(file, readExtendedTables);
|
return readWord2Vec(file, readExtendedTables);
|
||||||
return word2Vec;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -3139,11 +3208,12 @@ public class WordVectorSerializer {
|
||||||
* @param readExtendedTables boolean
|
* @param readExtendedTables boolean
|
||||||
* @return Word2Vec
|
* @return Word2Vec
|
||||||
*/
|
*/
|
||||||
public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables)
|
public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables) {
|
||||||
throws IOException {
|
try (InputStream inputStream = fileStream(file)) {
|
||||||
|
return readWord2Vec(inputStream, readExtendedTables);
|
||||||
Word2Vec word2Vec = readWord2Vec(new FileInputStream(file), readExtendedTables);
|
} catch (Exception readSequenceVectors) {
|
||||||
return word2Vec;
|
throw new RuntimeException(readSequenceVectors);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -3153,13 +3223,19 @@ public class WordVectorSerializer {
|
||||||
* @param readExtendedTable boolean
|
* @param readExtendedTable boolean
|
||||||
* @return Word2Vec
|
* @return Word2Vec
|
||||||
*/
|
*/
|
||||||
public static Word2Vec readWord2Vec(@NonNull InputStream stream,
|
public static Word2Vec readWord2Vec(
|
||||||
|
@NonNull InputStream stream,
|
||||||
boolean readExtendedTable) throws IOException {
|
boolean readExtendedTable) throws IOException {
|
||||||
SequenceVectors<VocabWord> vectors = readSequenceVectors(stream, readExtendedTable);
|
SequenceVectors<VocabWord> vectors = readSequenceVectors(stream, readExtendedTable);
|
||||||
Word2Vec word2Vec = new Word2Vec.Builder(vectors.getConfiguration()).layerSize(vectors.getLayerSize()).build();
|
|
||||||
|
Word2Vec word2Vec = new Word2Vec
|
||||||
|
.Builder(vectors.getConfiguration())
|
||||||
|
.layerSize(vectors.getLayerSize())
|
||||||
|
.build();
|
||||||
word2Vec.setVocab(vectors.getVocab());
|
word2Vec.setVocab(vectors.getVocab());
|
||||||
word2Vec.setLookupTable(vectors.lookupTable());
|
word2Vec.setLookupTable(vectors.lookupTable());
|
||||||
word2Vec.setModelUtils(vectors.getModelUtils());
|
word2Vec.setModelUtils(vectors.getModelUtils());
|
||||||
|
|
||||||
return word2Vec;
|
return word2Vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,8 +37,6 @@ import java.io.File;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TsneTest extends BaseDL4JTest {
|
public class TsneTest extends BaseDL4JTest {
|
||||||
|
|
||||||
|
|
|
@ -14,17 +14,14 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.models.sequencevectors.serialization;
|
package org.deeplearning4j.models.embeddings.loader;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.commons.lang.StringUtils;
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||||
import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
|
import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
|
||||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
|
||||||
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
|
||||||
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
||||||
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
|
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
|
||||||
import org.deeplearning4j.models.fasttext.FastText;
|
import org.deeplearning4j.models.fasttext.FastText;
|
||||||
|
@ -47,7 +44,11 @@ import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.junit.Assert.assertNotNull;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class WordVectorSerializerTest extends BaseDL4JTest {
|
public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
|
@ -78,9 +79,10 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
|
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
|
||||||
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
|
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
|
||||||
|
|
||||||
InMemoryLookupTable<VocabWord> lookupTable =
|
InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
|
||||||
(InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
|
.Builder<VocabWord>()
|
||||||
.useAdaGrad(false).cache(cache)
|
.useAdaGrad(false)
|
||||||
|
.cache(cache)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
lookupTable.setSyn0(syn0);
|
lookupTable.setSyn0(syn0);
|
||||||
|
@ -92,7 +94,6 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
lookupTable(lookupTable).
|
lookupTable(lookupTable).
|
||||||
build();
|
build();
|
||||||
SequenceVectors<VocabWord> deser = null;
|
SequenceVectors<VocabWord> deser = null;
|
||||||
String json = StringUtils.EMPTY;
|
|
||||||
try {
|
try {
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||||
WordVectorSerializer.writeSequenceVectors(vectors, baos);
|
WordVectorSerializer.writeSequenceVectors(vectors, baos);
|
||||||
|
@ -126,9 +127,10 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
|
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
|
||||||
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
|
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
|
||||||
|
|
||||||
InMemoryLookupTable<VocabWord> lookupTable =
|
InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
|
||||||
(InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
|
.Builder<VocabWord>()
|
||||||
.useAdaGrad(false).cache(cache)
|
.useAdaGrad(false)
|
||||||
|
.cache(cache)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
lookupTable.setSyn0(syn0);
|
lookupTable.setSyn0(syn0);
|
||||||
|
@ -204,9 +206,10 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
|
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
|
||||||
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
|
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
|
||||||
|
|
||||||
InMemoryLookupTable<VocabWord> lookupTable =
|
InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
|
||||||
(InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
|
.Builder<VocabWord>()
|
||||||
.useAdaGrad(false).cache(cache)
|
.useAdaGrad(false)
|
||||||
|
.cache(cache)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
lookupTable.setSyn0(syn0);
|
lookupTable.setSyn0(syn0);
|
||||||
|
@ -252,9 +255,10 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
|
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
|
||||||
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
|
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
|
||||||
|
|
||||||
InMemoryLookupTable<VocabWord> lookupTable =
|
InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
|
||||||
(InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
|
.Builder<VocabWord>()
|
||||||
.useAdaGrad(false).cache(cache)
|
.useAdaGrad(false)
|
||||||
|
.cache(cache)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
lookupTable.setSyn0(syn0);
|
lookupTable.setSyn0(syn0);
|
||||||
|
@ -267,7 +271,6 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
WeightLookupTable<VocabWord> deser = null;
|
WeightLookupTable<VocabWord> deser = null;
|
||||||
try {
|
try {
|
||||||
WordVectorSerializer.writeLookupTable(lookupTable, file);
|
WordVectorSerializer.writeLookupTable(lookupTable, file);
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
|
||||||
deser = WordVectorSerializer.readLookupTable(file);
|
deser = WordVectorSerializer.readLookupTable(file);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("",e);
|
log.error("",e);
|
||||||
|
@ -305,7 +308,6 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
FastText deser = null;
|
FastText deser = null;
|
||||||
try {
|
try {
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
|
||||||
deser = WordVectorSerializer.readWordVectors(new File(dir, "some.data"));
|
deser = WordVectorSerializer.readWordVectors(new File(dir, "some.data"));
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("",e);
|
log.error("",e);
|
||||||
|
@ -323,4 +325,32 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
assertEquals(fastText.getInputFile(), deser.getInputFile());
|
assertEquals(fastText.getInputFile(), deser.getInputFile());
|
||||||
assertEquals(fastText.getOutputFile(), deser.getOutputFile());
|
assertEquals(fastText.getOutputFile(), deser.getOutputFile());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIsHeader_withValidHeader () {
|
||||||
|
|
||||||
|
/* Given */
|
||||||
|
AbstractCache<VocabWord> cache = new AbstractCache<>();
|
||||||
|
String line = "48 100";
|
||||||
|
|
||||||
|
/* When */
|
||||||
|
boolean isHeader = WordVectorSerializer.isHeader(line, cache);
|
||||||
|
|
||||||
|
/* Then */
|
||||||
|
assertTrue(isHeader);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIsHeader_notHeader () {
|
||||||
|
|
||||||
|
/* Given */
|
||||||
|
AbstractCache<VocabWord> cache = new AbstractCache<>();
|
||||||
|
String line = "your -0.0017603 0.0030831 0.00069072 0.0020581 -0.0050952 -2.2573e-05 -0.001141";
|
||||||
|
|
||||||
|
/* When */
|
||||||
|
boolean isHeader = WordVectorSerializer.isHeader(line, cache);
|
||||||
|
|
||||||
|
/* Then */
|
||||||
|
assertFalse(isHeader);
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -1,9 +1,9 @@
|
||||||
package org.deeplearning4j.models.fasttext;
|
package org.deeplearning4j.models.fasttext;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||||
import org.deeplearning4j.models.word2vec.Word2Vec;
|
import org.deeplearning4j.models.word2vec.Word2Vec;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
|
||||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
||||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
|
@ -14,13 +14,14 @@ import org.nd4j.common.primitives.Pair;
|
||||||
import org.nd4j.common.resources.Resources;
|
import org.nd4j.common.resources.Resources;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
|
import static org.hamcrest.CoreMatchers.hasItems;
|
||||||
|
import static org.hamcrest.MatcherAssert.assertThat;
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FastTextTest extends BaseDL4JTest {
|
public class FastTextTest extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@ -32,7 +33,6 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin");
|
private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin");
|
||||||
private File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec");
|
private File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec");
|
||||||
|
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
|
@ -90,7 +90,7 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void tesLoadCBOWModel() throws IOException {
|
public void tesLoadCBOWModel() {
|
||||||
|
|
||||||
FastText fastText = new FastText(cbowModelFile);
|
FastText fastText = new FastText(cbowModelFile);
|
||||||
fastText.test(cbowModelFile);
|
fastText.test(cbowModelFile);
|
||||||
|
@ -99,7 +99,7 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
assertEquals("enjoy", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
|
assertEquals("enjoy", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
|
||||||
|
|
||||||
double[] expected = {5.040466203354299E-4, 0.001005030469968915, 2.8882650076411664E-4, -6.413314840756357E-4, -1.78931062691845E-4, -0.0023157168179750443, -0.002215880434960127, 0.00274421414360404, -1.5344757412094623E-4, 4.6274057240225375E-4, -1.4383681991603225E-4, 3.7832374800927937E-4, 2.523412986192852E-4, 0.0018913350068032742, -0.0024741862434893847, -4.976555937901139E-4, 0.0039220210164785385, -0.001781729981303215, -6.010578363202512E-4, -0.00244093406945467, -7.98621098510921E-4, -0.0010007203090935946, -0.001640203408896923, 7.897148607298732E-4, 9.131592814810574E-4, -0.0013367272913455963, -0.0014030139427632093, -7.755287806503475E-4, -4.2878396925516427E-4, 6.912827957421541E-4, -0.0011824817629531026, -0.0036014916840940714, 0.004353308118879795, -7.073904271237552E-5, -9.646290563978255E-4, -0.0031849315855652094, 2.3360115301329643E-4, -2.9103990527801216E-4, -0.0022990566212683916, -0.002393763978034258, -0.001034979010000825, -0.0010725988540798426, 0.0018285386031493545, -0.0013178540393710136, -1.6632364713586867E-4, -1.4665909475297667E-5, 5.445032729767263E-4, 2.999933494720608E-4, -0.0014367225812748075, -0.002345481887459755, 0.001117417006753385, -8.688368834555149E-4, -0.001830018823966384, 0.0013242220738902688, -8.880519890226424E-4, -6.888324278406799E-4, -0.0036394784692674875, 0.002179111586883664, -1.7201311129610986E-4, 0.002365073887631297, 0.002688770182430744, 0.0023955567739903927, 0.001469283364713192, 0.0011803617235273123, 5.871498142369092E-4, -7.099180947989225E-4, 7.518937345594168E-4, -8.599072461947799E-4, -6.600041524507105E-4, -0.002724145073443651, -8.365285466425121E-4, 0.0013173354091122746, 0.001083166105672717, 0.0014539906987920403, -3.1698777456767857E-4, -2.387022686889395E-4, 1.9560157670639455E-4, 0.0020277926232665777, -0.0012741144746541977, -0.0013026101514697075, -1.5212174912448972E-4, 0.0014194383984431624, 0.0012500399025157094, 0.0013362085446715355, 3.692879108712077E-4, 4.319801155361347E-5, 0.0011261265026405454, 0.0017244465416297317, 5.564604725805111E-5, 0.002170475199818611, 0.0014707016525790095, 0.001303741242736578, 0.005553730763494968, -0.0011097051901742816, -0.0013661726843565702, 0.0014100460102781653, 0.0011811562580987811, -6.622733199037611E-4, 7.860265322960913E-4, -9.811905911192298E-4};
|
double[] expected = {5.040466203354299E-4, 0.001005030469968915, 2.8882650076411664E-4, -6.413314840756357E-4, -1.78931062691845E-4, -0.0023157168179750443, -0.002215880434960127, 0.00274421414360404, -1.5344757412094623E-4, 4.6274057240225375E-4, -1.4383681991603225E-4, 3.7832374800927937E-4, 2.523412986192852E-4, 0.0018913350068032742, -0.0024741862434893847, -4.976555937901139E-4, 0.0039220210164785385, -0.001781729981303215, -6.010578363202512E-4, -0.00244093406945467, -7.98621098510921E-4, -0.0010007203090935946, -0.001640203408896923, 7.897148607298732E-4, 9.131592814810574E-4, -0.0013367272913455963, -0.0014030139427632093, -7.755287806503475E-4, -4.2878396925516427E-4, 6.912827957421541E-4, -0.0011824817629531026, -0.0036014916840940714, 0.004353308118879795, -7.073904271237552E-5, -9.646290563978255E-4, -0.0031849315855652094, 2.3360115301329643E-4, -2.9103990527801216E-4, -0.0022990566212683916, -0.002393763978034258, -0.001034979010000825, -0.0010725988540798426, 0.0018285386031493545, -0.0013178540393710136, -1.6632364713586867E-4, -1.4665909475297667E-5, 5.445032729767263E-4, 2.999933494720608E-4, -0.0014367225812748075, -0.002345481887459755, 0.001117417006753385, -8.688368834555149E-4, -0.001830018823966384, 0.0013242220738902688, -8.880519890226424E-4, -6.888324278406799E-4, -0.0036394784692674875, 0.002179111586883664, -1.7201311129610986E-4, 0.002365073887631297, 0.002688770182430744, 0.0023955567739903927, 0.001469283364713192, 0.0011803617235273123, 5.871498142369092E-4, -7.099180947989225E-4, 7.518937345594168E-4, -8.599072461947799E-4, -6.600041524507105E-4, -0.002724145073443651, -8.365285466425121E-4, 0.0013173354091122746, 0.001083166105672717, 0.0014539906987920403, -3.1698777456767857E-4, -2.387022686889395E-4, 1.9560157670639455E-4, 0.0020277926232665777, -0.0012741144746541977, -0.0013026101514697075, -1.5212174912448972E-4, 0.0014194383984431624, 0.0012500399025157094, 0.0013362085446715355, 3.692879108712077E-4, 4.319801155361347E-5, 0.0011261265026405454, 0.0017244465416297317, 5.564604725805111E-5, 0.002170475199818611, 0.0014707016525790095, 0.001303741242736578, 0.005553730763494968, -0.0011097051901742816, -0.0013661726843565702, 0.0014100460102781653, 0.0011811562580987811, -6.622733199037611E-4, 7.860265322960913E-4, -9.811905911192298E-4};
|
||||||
assertArrayEquals(expected, fastText.getWordVector("enjoy"), 1e-4);
|
assertArrayEquals(expected, fastText.getWordVector("enjoy"), 2e-3);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -111,7 +111,7 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
|
assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
|
||||||
|
|
||||||
double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
|
double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
|
||||||
assertArrayEquals(expected, fastText.getWordVector("association"), 1e-4);
|
assertArrayEquals(expected, fastText.getWordVector("association"), 2e-3);
|
||||||
|
|
||||||
String label = fastText.predict(text);
|
String label = fastText.predict(text);
|
||||||
assertEquals("__label__soccer", label);
|
assertEquals("__label__soccer", label);
|
||||||
|
@ -126,7 +126,7 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
|
assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
|
||||||
|
|
||||||
double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
|
double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
|
||||||
assertArrayEquals(expected, fastText.getWordVector("association"), 1e-4);
|
assertArrayEquals(expected, fastText.getWordVector("association"), 2e-3);
|
||||||
|
|
||||||
String label = fastText.predict(text);
|
String label = fastText.predict(text);
|
||||||
fastText.wordsNearest("test",1);
|
fastText.wordsNearest("test",1);
|
||||||
|
@ -140,10 +140,10 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
|
|
||||||
Pair<String,Float> result = fastText.predictProbability(text);
|
Pair<String,Float> result = fastText.predictProbability(text);
|
||||||
assertEquals("__label__soccer", result.getFirst());
|
assertEquals("__label__soccer", result.getFirst());
|
||||||
assertEquals(-0.6930, result.getSecond(), 1e-4);
|
assertEquals(-0.6930, result.getSecond(), 2e-3);
|
||||||
|
|
||||||
assertEquals(48, fastText.vocabSize());
|
assertEquals(48, fastText.vocabSize());
|
||||||
assertEquals(0.0500, fastText.getLearningRate(), 1e-4);
|
assertEquals(0.0500, fastText.getLearningRate(), 2e-3);
|
||||||
assertEquals(100, fastText.getDimension());
|
assertEquals(100, fastText.getDimension());
|
||||||
assertEquals(5, fastText.getContextWindowSize());
|
assertEquals(5, fastText.getContextWindowSize());
|
||||||
assertEquals(5, fastText.getEpoch());
|
assertEquals(5, fastText.getEpoch());
|
||||||
|
@ -155,7 +155,7 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVocabulary() throws IOException {
|
public void testVocabulary() {
|
||||||
FastText fastText = new FastText(supModelFile);
|
FastText fastText = new FastText(supModelFile);
|
||||||
assertEquals(48, fastText.vocab().numWords());
|
assertEquals(48, fastText.vocab().numWords());
|
||||||
assertEquals(48, fastText.vocabSize());
|
assertEquals(48, fastText.vocabSize());
|
||||||
|
@ -171,78 +171,73 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLoadIterator() {
|
public void testLoadIterator() throws FileNotFoundException {
|
||||||
try {
|
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
FastText fastText =
|
FastText
|
||||||
FastText.builder().supervised(true).iterator(iter).build();
|
.builder()
|
||||||
fastText.loadIterator();
|
.supervised(true)
|
||||||
|
.iterator(iter)
|
||||||
} catch (IOException e) {
|
.build()
|
||||||
log.error("",e);
|
.loadIterator();
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected=IllegalStateException.class)
|
@Test(expected=IllegalStateException.class)
|
||||||
public void testState() {
|
public void testState() {
|
||||||
FastText fastText = new FastText();
|
FastText fastText = new FastText();
|
||||||
String label = fastText.predict("something");
|
fastText.predict("something");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPretrainedVectors() throws IOException {
|
public void testPretrainedVectors() throws IOException {
|
||||||
File output = testDir.newFile();
|
File output = testDir.newFile();
|
||||||
|
|
||||||
FastText fastText =
|
FastText fastText = FastText
|
||||||
FastText.builder().supervised(true).
|
.builder()
|
||||||
inputFile(inputFile.getAbsolutePath()).
|
.supervised(true)
|
||||||
pretrainedVectorsFile(supervisedVectors.getAbsolutePath()).
|
.inputFile(inputFile.getAbsolutePath())
|
||||||
outputFile(output.getAbsolutePath()).build();
|
.pretrainedVectorsFile(supervisedVectors.getAbsolutePath())
|
||||||
|
.outputFile(output.getAbsolutePath())
|
||||||
|
.build();
|
||||||
|
|
||||||
log.info("\nTraining supervised model ...\n");
|
log.info("\nTraining supervised model ...\n");
|
||||||
fastText.fit();
|
fastText.fit();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWordsStatistics() throws IOException {
|
public void testWordsStatistics() throws IOException {
|
||||||
|
|
||||||
File output = testDir.newFile();
|
File output = testDir.newFile();
|
||||||
|
|
||||||
FastText fastText =
|
FastText fastText = FastText
|
||||||
FastText.builder().supervised(true).
|
.builder()
|
||||||
inputFile(inputFile.getAbsolutePath()).
|
.supervised(true)
|
||||||
outputFile(output.getAbsolutePath()).build();
|
.inputFile(inputFile.getAbsolutePath())
|
||||||
|
.outputFile(output.getAbsolutePath())
|
||||||
|
.build();
|
||||||
|
|
||||||
log.info("\nTraining supervised model ...\n");
|
log.info("\nTraining supervised model ...\n");
|
||||||
fastText.fit();
|
fastText.fit();
|
||||||
|
|
||||||
Word2Vec word2Vec = WordVectorSerializer.readAsCsv(new File(output.getAbsolutePath() + ".vec"));
|
File file = new File(output.getAbsolutePath() + ".vec");
|
||||||
|
Word2Vec word2Vec = WordVectorSerializer.readAsCsv(file);
|
||||||
|
|
||||||
assertEquals(48, word2Vec.getVocab().numWords());
|
assertEquals(48, word2Vec.getVocab().numWords());
|
||||||
|
assertEquals("", 0.1667751520872116, word2Vec.similarity("Football", "teams"), 2e-3);
|
||||||
System.out.println(word2Vec.wordsNearest("association", 3));
|
assertEquals("", 0.10083991289138794, word2Vec.similarity("professional", "minutes"), 2e-3);
|
||||||
System.out.println(word2Vec.similarity("Football", "teams"));
|
assertEquals("", Double.NaN, word2Vec.similarity("java","cpp"), 0.0);
|
||||||
System.out.println(word2Vec.similarity("professional", "minutes"));
|
assertThat(word2Vec.wordsNearest("association", 3), hasItems("Football", "Soccer", "men's"));
|
||||||
System.out.println(word2Vec.similarity("java","cpp"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWordsNativeStatistics() throws IOException {
|
public void testWordsNativeStatistics() {
|
||||||
|
|
||||||
File output = testDir.newFile();
|
|
||||||
|
|
||||||
FastText fastText = new FastText();
|
FastText fastText = new FastText();
|
||||||
fastText.loadPretrainedVectors(supervisedVectors);
|
fastText.loadPretrainedVectors(supervisedVectors);
|
||||||
|
|
||||||
log.info("\nTraining supervised model ...\n");
|
log.info("\nTraining supervised model ...\n");
|
||||||
|
|
||||||
assertEquals(48, fastText.vocab().numWords());
|
assertEquals(48, fastText.vocab().numWords());
|
||||||
|
assertThat(fastText.wordsNearest("association", 3), hasItems("most","eleven","hours"));
|
||||||
String[] result = new String[3];
|
assertEquals(0.1657, fastText.similarity("Football", "teams"), 2e-3);
|
||||||
fastText.wordsNearest("association", 3).toArray(result);
|
assertEquals(0.3661, fastText.similarity("professional", "minutes"), 2e-3);
|
||||||
assertArrayEquals(new String[]{"most","eleven","hours"}, result);
|
assertEquals(Double.NaN, fastText.similarity("java","cpp"), 0.0);
|
||||||
assertEquals(0.1657, fastText.similarity("Football", "teams"), 1e-4);
|
|
||||||
assertEquals(0.3661, fastText.similarity("professional", "minutes"), 1e-4);
|
|
||||||
assertEquals(Double.NaN, fastText.similarity("java","cpp"), 1e-4);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,7 +47,9 @@ import java.io.ByteArrayInputStream;
|
||||||
import java.io.ByteArrayOutputStream;
|
import java.io.ByteArrayOutputStream;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
|
import java.util.concurrent.Callable;
|
||||||
|
|
||||||
|
import static org.awaitility.Awaitility.await;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
|
||||||
|
@ -190,22 +192,26 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
|
||||||
.nOut(4).build())
|
.nOut(4).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
final MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
INDArray w0 = net.getParam("0_W");
|
INDArray w0 = net.getParam("0_W");
|
||||||
assertEquals(w, w0);
|
assertEquals(w, w0);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||||
ModelSerializer.writeModel(net, baos, true);
|
ModelSerializer.writeModel(net, baos, true);
|
||||||
byte[] bytes = baos.toByteArray();
|
byte[] bytes = baos.toByteArray();
|
||||||
|
|
||||||
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
||||||
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
|
final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
|
||||||
|
|
||||||
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
|
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
|
||||||
assertEquals(net.params(), restored.params());
|
await()
|
||||||
|
.until(new Callable<Boolean>() {
|
||||||
|
@Override
|
||||||
|
public Boolean call() {
|
||||||
|
return net.params().equalsWithEps(restored.params(), 2e-3);
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue