arrays = new ArrayList<>();
- long[] vShape = new long[]{1, -1};
- while (iter.hasNext()) {
- if (line.isEmpty())
- line = iter.nextLine();
- String[] split = line.split(" ");
- String word = ReadHelper.decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
- VocabWord word1 = new VocabWord(1.0, word);
+ /* We should check for something that looks like proper word vectors here. i.e: 1 word at the 0
+ * position, and bunch of floats further */
+ String[] headers = line.split(" ");
- word1.setIndex(cache.numWords());
+ try {
+ long[] header = new long[headers.length];
+ for (int x = 0; x < headers.length; x++) {
+ header[x] = Long.parseLong(headers[x]);
+ }
- cache.addToken(word1);
+ /* Now we know, if that's all ints - it's just a header
+ * [0] - number of words
+ * [1] - vectorLength
+ * [2] - number of documents <-- DL4j-only value
+ */
+ if (headers.length == 3) {
+ long numberOfDocuments = header[2];
+ cache.incrementTotalDocCount(numberOfDocuments);
+ }
- cache.addWordToIndex(word1.getIndex(), word);
+ long numWords = header[0];
+ int vectorLength = (int) header[1];
+ printOutProjectedMemoryUse(numWords, vectorLength, 1);
- cache.putVocabWord(word);
-
- float[] vector = new float[split.length - 1];
-
- for (int i = 1; i < split.length; i++) {
- vector[i - 1] = Float.parseFloat(split[i]);
+ return true;
+ } catch (Exception notHeaderException) {
+ // if any conversion exception hits - that'll be considered header
+ return false;
}
-
- vShape[1] = vector.length;
- INDArray row = Nd4j.create(vector, vShape);
-
- arrays.add(row);
-
- // workaround for skipped first row
- line = "";
}
-
- INDArray syn = Nd4j.vstack(arrays);
-
- InMemoryLookupTable lookupTable =
- (InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(arrays.get(0).columns())
- .useAdaGrad(false).cache(cache).useHierarchicSoftmax(false).build();
-
- lookupTable.setSyn0(syn);
-
- iter.close();
-
- try {
- reader.close();
- } catch (Exception e) {
- }
-
- return new Pair<>(lookupTable, (VocabCache) cache);
}
/**
@@ -2352,22 +2411,6 @@ public class WordVectorSerializer {
}
}
- /**
- * This method
- * 1) Binary model, either compressed or not. Like well-known Google Model
- * 2) Popular CSV word2vec text format
- * 3) DL4j compressed format
- *
- * Please note: Only weights will be loaded by this method.
- *
- * @param file
- * @return
- */
- public static Word2Vec readWord2VecModel(@NonNull File file) {
- return readWord2VecModel(file, false);
- }
-
-
/**
* This method
* 1) Binary model, either compressed or not. Like well-known Google Model
@@ -2389,106 +2432,196 @@ public class WordVectorSerializer {
* 2) Popular CSV word2vec text format
* 3) DL4j compressed format
*
- * Please note: if extended data isn't available, only weights will be loaded instead.
+ * Please note: Only weights will be loaded by this method.
*
- * @param path
- * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
+ * @param path path to model file
+ * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
* @return
*/
public static Word2Vec readWord2VecModel(String path, boolean extendedModel) {
return readWord2VecModel(new File(path), extendedModel);
}
- public static Word2Vec readAsBinaryNoLineBreaks(@NonNull File file) {
+ /**
+ * This method
+ * 1) Binary model, either compressed or not. Like well-known Google Model
+ * 2) Popular CSV word2vec text format
+ * 3) DL4j compressed format
+ *
+ * Please note: Only weights will be loaded by this method.
+ *
+ * @param file
+ * @return
+ */
+ public static Word2Vec readWord2VecModel(File file) {
+ return readWord2VecModel(file, false);
+ }
+
+ /**
+ * This method
+ * 1) Binary model, either compressed or not. Like well-known Google Model
+ * 2) Popular CSV word2vec text format
+ * 3) DL4j compressed format
+ *
+ * Please note: if extended data isn't available, only weights will be loaded instead.
+ *
+ * @param file model file
+ * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
+ * @return word2vec model
+ */
+ public static Word2Vec readWord2VecModel(File file, boolean extendedModel) {
+ if (!file.exists() || !file.isFile()) {
+ throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
+ }
+ boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
+ if (originalPeriodic) {
+ Nd4j.getMemoryManager().togglePeriodicGc(false);
+ }
+ Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
+
+ try {
+ return readWord2Vec(file, extendedModel);
+ } catch (Exception readSequenceVectors) {
+ try {
+ return extendedModel
+ ? readAsExtendedModel(file)
+ : readAsSimplifiedModel(file);
+ } catch (Exception loadFromFileException) {
+ try {
+ return readAsCsv(file);
+ } catch (Exception readCsvException) {
+ try {
+ return readAsBinary(file);
+ } catch (Exception readBinaryException) {
+ try {
+ return readAsBinaryNoLineBreaks(file);
+ } catch (Exception readModelException) {
+ log.error("Unable to guess input file format", readModelException);
+ throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
+ }
+ }
+ }
+ }
+ }
+ }
+
+ public static Word2Vec readAsBinaryNoLineBreaks(@NonNull File file) {
+ try (InputStream inputStream = fileStream(file)) {
+ return readAsBinaryNoLineBreaks(inputStream);
+ } catch (IOException readCsvException) {
+ throw new RuntimeException(readCsvException);
+ }
+ }
+
+ public static Word2Vec readAsBinaryNoLineBreaks(@NonNull InputStream inputStream) {
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
- Word2Vec vec;
// try to load without linebreaks
try {
- if (originalPeriodic)
+ if (originalPeriodic) {
Nd4j.getMemoryManager().togglePeriodicGc(true);
+ }
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
- vec = readBinaryModel(file, false, false);
- return vec;
- } catch (Exception ez) {
- throw new RuntimeException(
- "Unable to guess input file format. Please use corresponding loader directly");
+ return readBinaryModel(inputStream, false, false);
+ } catch (Exception readModelException) {
+ log.error("Cannot read binary model", readModelException);
+ throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
+ }
+ }
+
+ public static Word2Vec readAsBinary(@NonNull File file) {
+ try (InputStream inputStream = fileStream(file)) {
+ return readAsBinary(inputStream);
+ } catch (IOException readCsvException) {
+ throw new RuntimeException(readCsvException);
}
}
/**
- * This method loads Word2Vec model from binary file
+ * This method loads Word2Vec model from binary input stream.
*
- * @param file File
- * @return Word2Vec
+ * @param inputStream binary input stream
+ * @return Word2Vec
*/
- public static Word2Vec readAsBinary(@NonNull File file) {
+ public static Word2Vec readAsBinary(@NonNull InputStream inputStream) {
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
- Word2Vec vec;
-
// we fallback to trying binary model instead
try {
log.debug("Trying binary model restoration...");
- if (originalPeriodic)
+ if (originalPeriodic) {
Nd4j.getMemoryManager().togglePeriodicGc(true);
+ }
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
- vec = readBinaryModel(file, true, false);
- return vec;
- } catch (Exception ey) {
- throw new RuntimeException(ey);
+ return readBinaryModel(inputStream, true, false);
+ } catch (Exception readModelException) {
+ throw new RuntimeException(readModelException);
+ }
+ }
+
+ public static Word2Vec readAsCsv(@NonNull File file) {
+ try (InputStream inputStream = fileStream(file)) {
+ return readAsCsv(inputStream);
+ } catch (IOException readCsvException) {
+ throw new RuntimeException(readCsvException);
}
}
/**
* This method loads Word2Vec model from csv file
*
- * @param file File
- * @return Word2Vec
+ * @param inputStream input stream
+ * @return Word2Vec model
*/
- public static Word2Vec readAsCsv(@NonNull File file) {
-
- Word2Vec vec;
+ public static Word2Vec readAsCsv(@NonNull InputStream inputStream) {
VectorsConfiguration configuration = new VectorsConfiguration();
// let's try to load this file as csv file
try {
log.debug("Trying CSV model restoration...");
- Pair pair = loadTxt(file);
- Word2Vec.Builder builder = new Word2Vec.Builder().lookupTable(pair.getFirst()).useAdaGrad(false)
- .vocabCache(pair.getSecond()).layerSize(pair.getFirst().layerSize())
+ Pair pair = loadTxt(inputStream);
+ Word2Vec.Builder builder = new Word2Vec
+ .Builder()
+ .lookupTable(pair.getFirst())
+ .useAdaGrad(false)
+ .vocabCache(pair.getSecond())
+ .layerSize(pair.getFirst().layerSize())
// we don't use hs here, because model is incomplete
- .useHierarchicSoftmax(false).resetModel(false);
+ .useHierarchicSoftmax(false)
+ .resetModel(false);
TokenizerFactory factory = getTokenizerFactory(configuration);
- if (factory != null)
+ if (factory != null) {
builder.tokenizerFactory(factory);
+ }
- vec = builder.build();
- return vec;
+ return builder.build();
} catch (Exception ex) {
throw new RuntimeException("Unable to load model in CSV format");
}
}
+ /**
+ * This method just loads full compressed model.
+ */
private static Word2Vec readAsExtendedModel(@NonNull File file) throws IOException {
int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
log.debug("Trying full model restoration...");
- // this method just loads full compressed model
- if (originalPeriodic)
+ if (originalPeriodic) {
Nd4j.getMemoryManager().togglePeriodicGc(true);
+ }
Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
@@ -2627,67 +2760,6 @@ public class WordVectorSerializer {
return vec;
}
- /**
- * This method
- * 1) Binary model, either compressed or not. Like well-known Google Model
- * 2) Popular CSV word2vec text format
- * 3) DL4j compressed format
- *
- * Please note: if extended data isn't available, only weights will be loaded instead.
- *
- * @param file
- * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
- * @return
- */
- public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedModel) {
-
- if (!file.exists() || !file.isFile())
- throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
-
- Word2Vec vec = null;
-
- int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
- boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
- if (originalPeriodic)
- Nd4j.getMemoryManager().togglePeriodicGc(false);
- Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
-
- // try to load zip format
- try {
- vec = readWord2Vec(file, extendedModel);
- return vec;
- } catch (Exception e) {
- // let's try to load this file as csv file
- try {
- if (extendedModel) {
- vec = readAsExtendedModel(file);
- return vec;
- } else {
- vec = readAsSimplifiedModel(file);
- return vec;
- }
- } catch (Exception ex) {
- try {
- vec = readAsCsv(file);
- return vec;
- } catch (Exception exc) {
- try {
- vec = readAsBinary(file);
- return vec;
- } catch (Exception exce) {
- try {
- vec = readAsBinaryNoLineBreaks(file);
- return vec;
-
- } catch (Exception excep) {
- throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
- }
- }
- }
- }
- }
- }
-
protected static TokenizerFactory getTokenizerFactory(VectorsConfiguration configuration) {
if (configuration == null)
return null;
@@ -3019,16 +3091,13 @@ public class WordVectorSerializer {
/**
* This method restores Word2Vec model from file
*
- * @param path String
- * @param readExtendedTables booleab
+ * @param path
+ * @param readExtendedTables
* @return Word2Vec
*/
- public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables)
- throws IOException {
-
+ public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables) {
File file = new File(path);
- Word2Vec word2Vec = readWord2Vec(file, readExtendedTables);
- return word2Vec;
+ return readWord2Vec(file, readExtendedTables);
}
/**
@@ -3139,11 +3208,12 @@ public class WordVectorSerializer {
* @param readExtendedTables boolean
* @return Word2Vec
*/
- public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables)
- throws IOException {
-
- Word2Vec word2Vec = readWord2Vec(new FileInputStream(file), readExtendedTables);
- return word2Vec;
+ public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables) {
+ try (InputStream inputStream = fileStream(file)) {
+ return readWord2Vec(inputStream, readExtendedTables);
+ } catch (Exception readSequenceVectors) {
+ throw new RuntimeException(readSequenceVectors);
+ }
}
/**
@@ -3153,13 +3223,19 @@ public class WordVectorSerializer {
* @param readExtendedTable boolean
* @return Word2Vec
*/
- public static Word2Vec readWord2Vec(@NonNull InputStream stream,
- boolean readExtendedTable) throws IOException {
+ public static Word2Vec readWord2Vec(
+ @NonNull InputStream stream,
+ boolean readExtendedTable) throws IOException {
SequenceVectors vectors = readSequenceVectors(stream, readExtendedTable);
- Word2Vec word2Vec = new Word2Vec.Builder(vectors.getConfiguration()).layerSize(vectors.getLayerSize()).build();
+
+ Word2Vec word2Vec = new Word2Vec
+ .Builder(vectors.getConfiguration())
+ .layerSize(vectors.getLayerSize())
+ .build();
word2Vec.setVocab(vectors.getVocab());
word2Vec.setLookupTable(vectors.lookupTable());
word2Vec.setModelUtils(vectors.getModelUtils());
+
return word2Vec;
}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java
index 69fcd236c..5466bc15b 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java
@@ -37,8 +37,6 @@ import java.io.File;
import java.util.ArrayList;
import java.util.List;
-import static org.junit.Assert.assertEquals;
-
@Slf4j
public class TsneTest extends BaseDL4JTest {
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializerTest.java
similarity index 86%
rename from deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/WordVectorSerializerTest.java
rename to deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializerTest.java
index b7aff923e..f089a6ae9 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/WordVectorSerializerTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializerTest.java
@@ -14,17 +14,14 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
-package org.deeplearning4j.models.sequencevectors.serialization;
+package org.deeplearning4j.models.embeddings.loader;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
-import org.apache.commons.lang.StringUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
-import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
-import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
import org.deeplearning4j.models.fasttext.FastText;
@@ -47,7 +44,11 @@ import java.io.File;
import java.io.IOException;
import java.util.Collections;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
@Slf4j
public class WordVectorSerializerTest extends BaseDL4JTest {
@@ -78,10 +79,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
- InMemoryLookupTable lookupTable =
- (InMemoryLookupTable) new InMemoryLookupTable.Builder()
- .useAdaGrad(false).cache(cache)
- .build();
+ InMemoryLookupTable lookupTable = new InMemoryLookupTable
+ .Builder()
+ .useAdaGrad(false)
+ .cache(cache)
+ .build();
lookupTable.setSyn0(syn0);
lookupTable.setSyn1(syn1);
@@ -92,7 +94,6 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
lookupTable(lookupTable).
build();
SequenceVectors deser = null;
- String json = StringUtils.EMPTY;
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
WordVectorSerializer.writeSequenceVectors(vectors, baos);
@@ -126,10 +127,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
- InMemoryLookupTable lookupTable =
- (InMemoryLookupTable) new InMemoryLookupTable.Builder()
- .useAdaGrad(false).cache(cache)
- .build();
+ InMemoryLookupTable lookupTable = new InMemoryLookupTable
+ .Builder()
+ .useAdaGrad(false)
+ .cache(cache)
+ .build();
lookupTable.setSyn0(syn0);
lookupTable.setSyn1(syn1);
@@ -204,10 +206,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
- InMemoryLookupTable lookupTable =
- (InMemoryLookupTable) new InMemoryLookupTable.Builder()
- .useAdaGrad(false).cache(cache)
- .build();
+ InMemoryLookupTable lookupTable = new InMemoryLookupTable
+ .Builder()
+ .useAdaGrad(false)
+ .cache(cache)
+ .build();
lookupTable.setSyn0(syn0);
lookupTable.setSyn1(syn1);
@@ -252,10 +255,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);
- InMemoryLookupTable lookupTable =
- (InMemoryLookupTable) new InMemoryLookupTable.Builder()
- .useAdaGrad(false).cache(cache)
- .build();
+ InMemoryLookupTable lookupTable = new InMemoryLookupTable
+ .Builder()
+ .useAdaGrad(false)
+ .cache(cache)
+ .build();
lookupTable.setSyn0(syn0);
lookupTable.setSyn1(syn1);
@@ -267,7 +271,6 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
WeightLookupTable deser = null;
try {
WordVectorSerializer.writeLookupTable(lookupTable, file);
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
deser = WordVectorSerializer.readLookupTable(file);
} catch (Exception e) {
log.error("",e);
@@ -305,7 +308,6 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
FastText deser = null;
try {
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
deser = WordVectorSerializer.readWordVectors(new File(dir, "some.data"));
} catch (Exception e) {
log.error("",e);
@@ -323,4 +325,32 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
assertEquals(fastText.getInputFile(), deser.getInputFile());
assertEquals(fastText.getOutputFile(), deser.getOutputFile());
}
+
+ @Test
+ public void testIsHeader_withValidHeader () {
+
+ /* Given */
+ AbstractCache cache = new AbstractCache<>();
+ String line = "48 100";
+
+ /* When */
+ boolean isHeader = WordVectorSerializer.isHeader(line, cache);
+
+ /* Then */
+ assertTrue(isHeader);
+ }
+
+ @Test
+ public void testIsHeader_notHeader () {
+
+ /* Given */
+ AbstractCache cache = new AbstractCache<>();
+ String line = "your -0.0017603 0.0030831 0.00069072 0.0020581 -0.0050952 -2.2573e-05 -0.001141";
+
+ /* When */
+ boolean isHeader = WordVectorSerializer.isHeader(line, cache);
+
+ /* Then */
+ assertFalse(isHeader);
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java
index 4f0548ef5..4c89cfa1c 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java
@@ -1,9 +1,9 @@
package org.deeplearning4j.models.fasttext;
import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
-import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.junit.Rule;
@@ -14,13 +14,14 @@ import org.nd4j.common.primitives.Pair;
import org.nd4j.common.resources.Resources;
import java.io.File;
+import java.io.FileNotFoundException;
import java.io.IOException;
-
+import static org.hamcrest.CoreMatchers.hasItems;
+import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
-
@Slf4j
public class FastTextTest extends BaseDL4JTest {
@@ -32,7 +33,6 @@ public class FastTextTest extends BaseDL4JTest {
private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin");
private File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec");
-
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@@ -90,7 +90,7 @@ public class FastTextTest extends BaseDL4JTest {
}
@Test
- public void tesLoadCBOWModel() throws IOException {
+ public void tesLoadCBOWModel() {
FastText fastText = new FastText(cbowModelFile);
fastText.test(cbowModelFile);
@@ -99,7 +99,7 @@ public class FastTextTest extends BaseDL4JTest {
assertEquals("enjoy", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
double[] expected = {5.040466203354299E-4, 0.001005030469968915, 2.8882650076411664E-4, -6.413314840756357E-4, -1.78931062691845E-4, -0.0023157168179750443, -0.002215880434960127, 0.00274421414360404, -1.5344757412094623E-4, 4.6274057240225375E-4, -1.4383681991603225E-4, 3.7832374800927937E-4, 2.523412986192852E-4, 0.0018913350068032742, -0.0024741862434893847, -4.976555937901139E-4, 0.0039220210164785385, -0.001781729981303215, -6.010578363202512E-4, -0.00244093406945467, -7.98621098510921E-4, -0.0010007203090935946, -0.001640203408896923, 7.897148607298732E-4, 9.131592814810574E-4, -0.0013367272913455963, -0.0014030139427632093, -7.755287806503475E-4, -4.2878396925516427E-4, 6.912827957421541E-4, -0.0011824817629531026, -0.0036014916840940714, 0.004353308118879795, -7.073904271237552E-5, -9.646290563978255E-4, -0.0031849315855652094, 2.3360115301329643E-4, -2.9103990527801216E-4, -0.0022990566212683916, -0.002393763978034258, -0.001034979010000825, -0.0010725988540798426, 0.0018285386031493545, -0.0013178540393710136, -1.6632364713586867E-4, -1.4665909475297667E-5, 5.445032729767263E-4, 2.999933494720608E-4, -0.0014367225812748075, -0.002345481887459755, 0.001117417006753385, -8.688368834555149E-4, -0.001830018823966384, 0.0013242220738902688, -8.880519890226424E-4, -6.888324278406799E-4, -0.0036394784692674875, 0.002179111586883664, -1.7201311129610986E-4, 0.002365073887631297, 0.002688770182430744, 0.0023955567739903927, 0.001469283364713192, 0.0011803617235273123, 5.871498142369092E-4, -7.099180947989225E-4, 7.518937345594168E-4, -8.599072461947799E-4, -6.600041524507105E-4, -0.002724145073443651, -8.365285466425121E-4, 0.0013173354091122746, 0.001083166105672717, 0.0014539906987920403, -3.1698777456767857E-4, -2.387022686889395E-4, 1.9560157670639455E-4, 0.0020277926232665777, -0.0012741144746541977, -0.0013026101514697075, -1.5212174912448972E-4, 0.0014194383984431624, 0.0012500399025157094, 0.0013362085446715355, 3.692879108712077E-4, 4.319801155361347E-5, 0.0011261265026405454, 0.0017244465416297317, 5.564604725805111E-5, 0.002170475199818611, 0.0014707016525790095, 0.001303741242736578, 0.005553730763494968, -0.0011097051901742816, -0.0013661726843565702, 0.0014100460102781653, 0.0011811562580987811, -6.622733199037611E-4, 7.860265322960913E-4, -9.811905911192298E-4};
- assertArrayEquals(expected, fastText.getWordVector("enjoy"), 1e-4);
+ assertArrayEquals(expected, fastText.getWordVector("enjoy"), 2e-3);
}
@Test
@@ -111,7 +111,7 @@ public class FastTextTest extends BaseDL4JTest {
assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
- assertArrayEquals(expected, fastText.getWordVector("association"), 1e-4);
+ assertArrayEquals(expected, fastText.getWordVector("association"), 2e-3);
String label = fastText.predict(text);
assertEquals("__label__soccer", label);
@@ -126,7 +126,7 @@ public class FastTextTest extends BaseDL4JTest {
assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
- assertArrayEquals(expected, fastText.getWordVector("association"), 1e-4);
+ assertArrayEquals(expected, fastText.getWordVector("association"), 2e-3);
String label = fastText.predict(text);
fastText.wordsNearest("test",1);
@@ -140,10 +140,10 @@ public class FastTextTest extends BaseDL4JTest {
Pair result = fastText.predictProbability(text);
assertEquals("__label__soccer", result.getFirst());
- assertEquals(-0.6930, result.getSecond(), 1e-4);
+ assertEquals(-0.6930, result.getSecond(), 2e-3);
assertEquals(48, fastText.vocabSize());
- assertEquals(0.0500, fastText.getLearningRate(), 1e-4);
+ assertEquals(0.0500, fastText.getLearningRate(), 2e-3);
assertEquals(100, fastText.getDimension());
assertEquals(5, fastText.getContextWindowSize());
assertEquals(5, fastText.getEpoch());
@@ -155,7 +155,7 @@ public class FastTextTest extends BaseDL4JTest {
}
@Test
- public void testVocabulary() throws IOException {
+ public void testVocabulary() {
FastText fastText = new FastText(supModelFile);
assertEquals(48, fastText.vocab().numWords());
assertEquals(48, fastText.vocabSize());
@@ -171,78 +171,73 @@ public class FastTextTest extends BaseDL4JTest {
}
@Test
- public void testLoadIterator() {
- try {
- SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
- FastText fastText =
- FastText.builder().supervised(true).iterator(iter).build();
- fastText.loadIterator();
-
- } catch (IOException e) {
- log.error("",e);
- }
+ public void testLoadIterator() throws FileNotFoundException {
+ SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
+ FastText
+ .builder()
+ .supervised(true)
+ .iterator(iter)
+ .build()
+ .loadIterator();
}
@Test(expected=IllegalStateException.class)
public void testState() {
FastText fastText = new FastText();
- String label = fastText.predict("something");
+ fastText.predict("something");
}
@Test
public void testPretrainedVectors() throws IOException {
File output = testDir.newFile();
- FastText fastText =
- FastText.builder().supervised(true).
- inputFile(inputFile.getAbsolutePath()).
- pretrainedVectorsFile(supervisedVectors.getAbsolutePath()).
- outputFile(output.getAbsolutePath()).build();
+ FastText fastText = FastText
+ .builder()
+ .supervised(true)
+ .inputFile(inputFile.getAbsolutePath())
+ .pretrainedVectorsFile(supervisedVectors.getAbsolutePath())
+ .outputFile(output.getAbsolutePath())
+ .build();
+
log.info("\nTraining supervised model ...\n");
fastText.fit();
}
@Test
public void testWordsStatistics() throws IOException {
-
File output = testDir.newFile();
- FastText fastText =
- FastText.builder().supervised(true).
- inputFile(inputFile.getAbsolutePath()).
- outputFile(output.getAbsolutePath()).build();
+ FastText fastText = FastText
+ .builder()
+ .supervised(true)
+ .inputFile(inputFile.getAbsolutePath())
+ .outputFile(output.getAbsolutePath())
+ .build();
log.info("\nTraining supervised model ...\n");
fastText.fit();
- Word2Vec word2Vec = WordVectorSerializer.readAsCsv(new File(output.getAbsolutePath() + ".vec"));
+ File file = new File(output.getAbsolutePath() + ".vec");
+ Word2Vec word2Vec = WordVectorSerializer.readAsCsv(file);
- assertEquals(48, word2Vec.getVocab().numWords());
-
- System.out.println(word2Vec.wordsNearest("association", 3));
- System.out.println(word2Vec.similarity("Football", "teams"));
- System.out.println(word2Vec.similarity("professional", "minutes"));
- System.out.println(word2Vec.similarity("java","cpp"));
+ assertEquals(48, word2Vec.getVocab().numWords());
+ assertEquals("", 0.1667751520872116, word2Vec.similarity("Football", "teams"), 2e-3);
+ assertEquals("", 0.10083991289138794, word2Vec.similarity("professional", "minutes"), 2e-3);
+ assertEquals("", Double.NaN, word2Vec.similarity("java","cpp"), 0.0);
+ assertThat(word2Vec.wordsNearest("association", 3), hasItems("Football", "Soccer", "men's"));
}
-
@Test
- public void testWordsNativeStatistics() throws IOException {
-
- File output = testDir.newFile();
-
+ public void testWordsNativeStatistics() {
FastText fastText = new FastText();
fastText.loadPretrainedVectors(supervisedVectors);
log.info("\nTraining supervised model ...\n");
assertEquals(48, fastText.vocab().numWords());
-
- String[] result = new String[3];
- fastText.wordsNearest("association", 3).toArray(result);
- assertArrayEquals(new String[]{"most","eleven","hours"}, result);
- assertEquals(0.1657, fastText.similarity("Football", "teams"), 1e-4);
- assertEquals(0.3661, fastText.similarity("professional", "minutes"), 1e-4);
- assertEquals(Double.NaN, fastText.similarity("java","cpp"), 1e-4);
+ assertThat(fastText.wordsNearest("association", 3), hasItems("most","eleven","hours"));
+ assertEquals(0.1657, fastText.similarity("Football", "teams"), 2e-3);
+ assertEquals(0.3661, fastText.similarity("professional", "minutes"), 2e-3);
+ assertEquals(Double.NaN, fastText.similarity("java","cpp"), 0.0);
}
}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java
index c9cc8f072..38b44d1ff 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java
@@ -47,7 +47,9 @@ import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.util.Collection;
+import java.util.concurrent.Callable;
+import static org.awaitility.Awaitility.await;
import static org.junit.Assert.assertEquals;
@@ -190,22 +192,26 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
.nOut(4).build())
.build();
- MultiLayerNetwork net = new MultiLayerNetwork(conf);
+ final MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray w0 = net.getParam("0_W");
assertEquals(w, w0);
-
-
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ModelSerializer.writeModel(net, baos, true);
byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
- MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
+ final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
- assertEquals(net.params(), restored.params());
+ await()
+ .until(new Callable() {
+ @Override
+ public Boolean call() {
+ return net.params().equalsWithEps(restored.params(), 2e-3);
+ }
+ });
}
}