Various fixes (#141)
* #8121 CnnSentenceDataSetIterator fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8120 CnnSentenceDataSetIterator.loadSingleSentence no words UX/exception improvement
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8122 AggregatingSentenceIterator builder - addSentencePreProcessor -> sentencePreProcessor
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8082 Arbiter - fix GridSearchCandidateGenerator search size issue
Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
parent
0adce9a4fa
commit
348d9c59f7
|
@ -19,6 +19,7 @@ package org.deeplearning4j.arbiter.optimize.api.adapter;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
@ -37,6 +38,8 @@ public abstract class ParameterSpaceAdapter<F, T> implements ParameterSpace<T> {
|
||||||
|
|
||||||
protected abstract ParameterSpace<F> underlying();
|
protected abstract ParameterSpace<F> underlying();
|
||||||
|
|
||||||
|
protected abstract String underlyingName();
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public T getValue(double[] parameterValues) {
|
public T getValue(double[] parameterValues) {
|
||||||
|
@ -50,17 +53,21 @@ public abstract class ParameterSpaceAdapter<F, T> implements ParameterSpace<T> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<ParameterSpace> collectLeaves() {
|
public List<ParameterSpace> collectLeaves() {
|
||||||
|
ParameterSpace p = underlying();
|
||||||
|
if(p.isLeaf()){
|
||||||
|
return Collections.singletonList(p);
|
||||||
|
}
|
||||||
return underlying().collectLeaves();
|
return underlying().collectLeaves();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
public Map<String, ParameterSpace> getNestedSpaces() {
|
||||||
return underlying().getNestedSpaces();
|
return Collections.singletonMap(underlyingName(), (ParameterSpace)underlying());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isLeaf() {
|
public boolean isLeaf() {
|
||||||
return underlying().isLeaf();
|
return false; //Underlying may be a leaf, however
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -17,10 +17,12 @@
|
||||||
package org.deeplearning4j.arbiter.optimize.generator;
|
package org.deeplearning4j.arbiter.optimize.generator;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.Getter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.math3.random.RandomAdaptor;
|
import org.apache.commons.math3.random.RandomAdaptor;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
|
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
|
||||||
import org.deeplearning4j.arbiter.util.LeafUtils;
|
import org.deeplearning4j.arbiter.util.LeafUtils;
|
||||||
|
@ -65,6 +67,7 @@ public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
|
||||||
private final Mode mode;
|
private final Mode mode;
|
||||||
|
|
||||||
private int[] numValuesPerParam;
|
private int[] numValuesPerParam;
|
||||||
|
@Getter
|
||||||
private int totalNumCandidates;
|
private int totalNumCandidates;
|
||||||
private Queue<Integer> order;
|
private Queue<Integer> order;
|
||||||
|
|
||||||
|
@ -123,6 +126,8 @@ public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
|
||||||
int max = ips.getMax();
|
int max = ips.getMax();
|
||||||
//Discretize, as some integer ranges are much too large to search (i.e., num. neural network units, between 100 and 1000)
|
//Discretize, as some integer ranges are much too large to search (i.e., num. neural network units, between 100 and 1000)
|
||||||
numValuesPerParam[i] = Math.min(max - min + 1, discretizationCount);
|
numValuesPerParam[i] = Math.min(max - min + 1, discretizationCount);
|
||||||
|
} else if (ps instanceof FixedValue){
|
||||||
|
numValuesPerParam[i] = 1;
|
||||||
} else {
|
} else {
|
||||||
numValuesPerParam[i] = discretizationCount;
|
numValuesPerParam[i] = discretizationCount;
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,10 +44,7 @@ import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This is an abstract ParameterSpace for both MultiLayerNetworks (MultiLayerSpace) and ComputationGraph (ComputationGraphSpace)
|
* This is an abstract ParameterSpace for both MultiLayerNetworks (MultiLayerSpace) and ComputationGraph (ComputationGraphSpace)
|
||||||
|
@ -212,18 +209,30 @@ public abstract class BaseNetworkSpace<T> extends AbstractParameterSpace<T> {
|
||||||
@Override
|
@Override
|
||||||
public List<ParameterSpace> collectLeaves() {
|
public List<ParameterSpace> collectLeaves() {
|
||||||
Map<String, ParameterSpace> global = getNestedSpaces();
|
Map<String, ParameterSpace> global = getNestedSpaces();
|
||||||
List<ParameterSpace> list = new ArrayList<>();
|
|
||||||
list.addAll(global.values());
|
|
||||||
|
|
||||||
//Note: Results on previous line does NOT include the LayerSpaces, therefore we need to add these manually...
|
//Note: Results on previous line does NOT include the LayerSpaces, therefore we need to add these manually...
|
||||||
//This is because the type is a list, not a ParameterSpace
|
//This is because the type is a list, not a ParameterSpace
|
||||||
|
LinkedList<ParameterSpace> stack = new LinkedList<>();
|
||||||
|
stack.add(this);
|
||||||
|
|
||||||
for (LayerConf layerConf : layerSpaces) {
|
for (LayerConf layerConf : layerSpaces) {
|
||||||
LayerSpace ls = layerConf.getLayerSpace();
|
LayerSpace ls = layerConf.getLayerSpace();
|
||||||
list.addAll(ls.collectLeaves());
|
stack.addAll(ls.collectLeaves());
|
||||||
}
|
}
|
||||||
|
|
||||||
return list;
|
List<ParameterSpace> out = new ArrayList<>();
|
||||||
|
while (!stack.isEmpty()) {
|
||||||
|
ParameterSpace next = stack.removeLast();
|
||||||
|
if (next.isLeaf()) {
|
||||||
|
out.add(next);
|
||||||
|
} else {
|
||||||
|
Map<String, ParameterSpace> m = next.getNestedSpaces();
|
||||||
|
ParameterSpace[] arr = m.values().toArray(new ParameterSpace[m.size()]);
|
||||||
|
for (int i = arr.length - 1; i >= 0; i--) {
|
||||||
|
stack.add(arr[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -84,8 +84,10 @@ public class MultiLayerSpace extends BaseNetworkSpace<DL4JConfiguration> {
|
||||||
List<ParameterSpace> allLeaves = collectLeaves();
|
List<ParameterSpace> allLeaves = collectLeaves();
|
||||||
List<ParameterSpace> list = LeafUtils.getUniqueObjects(allLeaves);
|
List<ParameterSpace> list = LeafUtils.getUniqueObjects(allLeaves);
|
||||||
|
|
||||||
for (ParameterSpace ps : list)
|
for (ParameterSpace ps : list) {
|
||||||
|
int n = ps.numParameters();
|
||||||
numParameters += ps.numParameters();
|
numParameters += ps.numParameters();
|
||||||
|
}
|
||||||
|
|
||||||
this.trainingWorkspaceMode = builder.trainingWorkspaceMode;
|
this.trainingWorkspaceMode = builder.trainingWorkspaceMode;
|
||||||
this.inferenceWorkspaceMode = builder.inferenceWorkspaceMode;
|
this.inferenceWorkspaceMode = builder.inferenceWorkspaceMode;
|
||||||
|
|
|
@ -50,4 +50,9 @@ public class ActivationParameterSpaceAdapter extends ParameterSpaceAdapter<Activ
|
||||||
protected ParameterSpace<Activation> underlying() {
|
protected ParameterSpace<Activation> underlying() {
|
||||||
return activation;
|
return activation;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected String underlyingName() {
|
||||||
|
return "activation";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,4 +52,9 @@ public class LossFunctionParameterSpaceAdapter
|
||||||
protected ParameterSpace<LossFunctions.LossFunction> underlying() {
|
protected ParameterSpace<LossFunctions.LossFunction> underlying() {
|
||||||
return lossFunction;
|
return lossFunction;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected String underlyingName() {
|
||||||
|
return "lossFunction";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,9 +52,11 @@ public class DropoutLayerSpace extends LayerSpace<DropoutLayer> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<ParameterSpace> collectLeaves() {
|
public List<ParameterSpace> collectLeaves() {
|
||||||
return Collections.<ParameterSpace>singletonList(dropOut);
|
return dropOut.collectLeaves();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isLeaf() {
|
public boolean isLeaf() {
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -31,6 +31,7 @@ import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
|
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
|
||||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
|
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
|
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
|
||||||
|
@ -706,4 +707,63 @@ public class TestMultiLayerSpace {
|
||||||
|
|
||||||
MultiLayerConfiguration conf = mls.getValue(new double[nParams]).getMultiLayerConfiguration();
|
MultiLayerConfiguration conf = mls.getValue(new double[nParams]).getMultiLayerConfiguration();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIssue8082(){
|
||||||
|
ParameterSpace<Double> learningRateHyperparam = new DiscreteParameterSpace<>(0.003, 0.005, 0.01, 0.05);
|
||||||
|
ParameterSpace<Integer> layerSizeHyperparam1 = new DiscreteParameterSpace<>(32, 64, 96, 128);
|
||||||
|
ParameterSpace<Integer> layerSizeHyperparam2 = new DiscreteParameterSpace<>(32, 64, 96, 128);
|
||||||
|
ParameterSpace<Double> dropoutHyperparam = new DiscreteParameterSpace<>(0.8, 0.9);
|
||||||
|
|
||||||
|
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
||||||
|
.updater(new AdamSpace(learningRateHyperparam))
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
.l2(0.0001)
|
||||||
|
.addLayer(new DenseLayerSpace.Builder()
|
||||||
|
.nIn(10)
|
||||||
|
.nOut(layerSizeHyperparam1)
|
||||||
|
.build())
|
||||||
|
.addLayer(new BatchNormalizationSpace.Builder()
|
||||||
|
.nOut(layerSizeHyperparam1)
|
||||||
|
.activation(Activation.RELU)
|
||||||
|
.build())
|
||||||
|
.addLayer(new DropoutLayerSpace.Builder()
|
||||||
|
.dropOut(dropoutHyperparam)
|
||||||
|
.build())
|
||||||
|
.addLayer(new DenseLayerSpace.Builder()
|
||||||
|
.nOut(layerSizeHyperparam2)
|
||||||
|
.build())
|
||||||
|
.addLayer(new BatchNormalizationSpace.Builder()
|
||||||
|
.nOut(layerSizeHyperparam2)
|
||||||
|
.activation(Activation.RELU)
|
||||||
|
.build())
|
||||||
|
.addLayer(new DropoutLayerSpace.Builder()
|
||||||
|
.dropOut(dropoutHyperparam)
|
||||||
|
.build())
|
||||||
|
.addLayer(new OutputLayerSpace.Builder()
|
||||||
|
.nOut(10)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.lossFunction(LossFunction.MCXENT)
|
||||||
|
.build())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
assertEquals(4, mls.getNumParameters());
|
||||||
|
|
||||||
|
for( int discreteCount : new int[]{1, 5}) {
|
||||||
|
GridSearchCandidateGenerator generator = new GridSearchCandidateGenerator(mls, discreteCount, GridSearchCandidateGenerator.Mode.Sequential, null);
|
||||||
|
|
||||||
|
int expCandidates = 4 * 4 * 4 * 2;
|
||||||
|
assertEquals(expCandidates, generator.getTotalNumCandidates());
|
||||||
|
|
||||||
|
int count = 0;
|
||||||
|
while (generator.hasMoreCandidates()) {
|
||||||
|
generator.getCandidate();
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
assertEquals(expCandidates, count);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -117,18 +117,22 @@ public class CnnSentenceDataSetIterator implements DataSetIterator {
|
||||||
List<String> sortedLabels = new ArrayList<>(this.sentenceProvider.allLabels());
|
List<String> sortedLabels = new ArrayList<>(this.sentenceProvider.allLabels());
|
||||||
Collections.sort(sortedLabels);
|
Collections.sort(sortedLabels);
|
||||||
|
|
||||||
|
this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
|
||||||
|
|
||||||
for (String s : sortedLabels) {
|
for (String s : sortedLabels) {
|
||||||
this.labelClassMap.put(s, count++);
|
this.labelClassMap.put(s, count++);
|
||||||
}
|
}
|
||||||
if (unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
|
if (unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
|
||||||
if (useNormalizedWordVectors) {
|
if (useNormalizedWordVectors) {
|
||||||
wordVectors.getWordVectorMatrixNormalized(wordVectors.getUNK());
|
unknown = wordVectors.getWordVectorMatrixNormalized(wordVectors.getUNK());
|
||||||
} else {
|
} else {
|
||||||
wordVectors.getWordVectorMatrix(wordVectors.getUNK());
|
unknown = wordVectors.getWordVectorMatrix(wordVectors.getUNK());
|
||||||
|
}
|
||||||
|
|
||||||
|
if(unknown == null){
|
||||||
|
unknown = wordVectors.getWordVectorMatrix(wordVectors.vocab().wordAtIndex(0)).like();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -136,6 +140,9 @@ public class CnnSentenceDataSetIterator implements DataSetIterator {
|
||||||
*/
|
*/
|
||||||
public INDArray loadSingleSentence(String sentence) {
|
public INDArray loadSingleSentence(String sentence) {
|
||||||
List<String> tokens = tokenizeSentence(sentence);
|
List<String> tokens = tokenizeSentence(sentence);
|
||||||
|
if(tokens.isEmpty())
|
||||||
|
throw new IllegalStateException("No tokens available for input sentence - empty string or no words in vocabulary with RemoveWord unknown handling? Sentence = \"" +
|
||||||
|
sentence + "\"");
|
||||||
if(format == Format.CNN1D || format == Format.RNN){
|
if(format == Format.CNN1D || format == Format.RNN){
|
||||||
int[] featuresShape = new int[] {1, wordVectorSize, Math.min(maxSentenceLength, tokens.size())};
|
int[] featuresShape = new int[] {1, wordVectorSize, Math.min(maxSentenceLength, tokens.size())};
|
||||||
INDArray features = Nd4j.create(featuresShape, (format == Format.CNN1D ? 'c' : 'f'));
|
INDArray features = Nd4j.create(featuresShape, (format == Format.CNN1D ? 'c' : 'f'));
|
||||||
|
|
|
@ -100,7 +100,15 @@ public class AggregatingSentenceIterator implements SentenceIterator {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Use {@link #sentencePreProcessor(SentencePreProcessor)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
public Builder addSentencePreProcessor(@NonNull SentencePreProcessor preProcessor) {
|
public Builder addSentencePreProcessor(@NonNull SentencePreProcessor preProcessor) {
|
||||||
|
return sentencePreProcessor(preProcessor);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder sentencePreProcessor(@NonNull SentencePreProcessor preProcessor) {
|
||||||
this.preProcessor = preProcessor;
|
this.preProcessor = preProcessor;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
|
@ -264,6 +264,21 @@ public class TestCnnSentenceDataSetIterator extends BaseDL4JTest {
|
||||||
assertEquals(expLabels, ds.getLabels());
|
assertEquals(expLabels, ds.getLabels());
|
||||||
assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
|
assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
|
||||||
assertNull(ds.getLabelsMaskArray());
|
assertNull(ds.getLabelsMaskArray());
|
||||||
|
|
||||||
|
|
||||||
|
//Sanity check on single sentence loading:
|
||||||
|
INDArray allKnownWords = dsi.loadSingleSentence("these balance");
|
||||||
|
INDArray withUnknown = dsi.loadSingleSentence("these NOVALID");
|
||||||
|
assertNotNull(allKnownWords);
|
||||||
|
assertNotNull(withUnknown);
|
||||||
|
|
||||||
|
try {
|
||||||
|
dsi.loadSingleSentence("NOVALID AlsoNotInVocab");
|
||||||
|
fail("Expected exception");
|
||||||
|
} catch (Throwable t){
|
||||||
|
String m = t.getMessage();
|
||||||
|
assertTrue(m, m.contains("RemoveWord") && m.contains("vocabulary"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -324,4 +339,56 @@ public class TestCnnSentenceDataSetIterator extends BaseDL4JTest {
|
||||||
assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
|
assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
|
||||||
assertNull(ds.getLabelsMaskArray());
|
assertNull(ds.getLabelsMaskArray());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnnSentenceDataSetIteratorUseUnknownVector() throws Exception {
|
||||||
|
|
||||||
|
WordVectors w2v = WordVectorSerializer
|
||||||
|
.readWord2VecModel(new ClassPathResource("word2vec/googleload/sample_vec.bin").getFile());
|
||||||
|
|
||||||
|
List<String> sentences = new ArrayList<>();
|
||||||
|
sentences.add("these balance Database model");
|
||||||
|
sentences.add("into same THISWORDDOESNTEXIST are");
|
||||||
|
//Last 2 sentences - no valid words
|
||||||
|
sentences.add("NOVALID WORDSHERE");
|
||||||
|
sentences.add("!!!");
|
||||||
|
|
||||||
|
List<String> labelsForSentences = Arrays.asList("Positive", "Negative", "Positive", "Negative");
|
||||||
|
|
||||||
|
|
||||||
|
LabeledSentenceProvider p = new CollectionLabeledSentenceProvider(sentences, labelsForSentences, null);
|
||||||
|
CnnSentenceDataSetIterator dsi = new CnnSentenceDataSetIterator.Builder(CnnSentenceDataSetIterator.Format.CNN1D)
|
||||||
|
.unknownWordHandling(CnnSentenceDataSetIterator.UnknownWordHandling.UseUnknownVector)
|
||||||
|
.sentenceProvider(p).wordVectors(w2v)
|
||||||
|
.useNormalizedWordVectors(true)
|
||||||
|
.maxSentenceLength(256).minibatchSize(4).sentencesAlongHeight(false).build();
|
||||||
|
|
||||||
|
assertTrue(dsi.hasNext());
|
||||||
|
DataSet ds = dsi.next();
|
||||||
|
|
||||||
|
assertFalse(dsi.hasNext());
|
||||||
|
|
||||||
|
INDArray f = ds.getFeatures();
|
||||||
|
assertEquals(4, f.size(0));
|
||||||
|
|
||||||
|
INDArray unknown = w2v.getWordVectorMatrix(w2v.getUNK());
|
||||||
|
if(unknown == null)
|
||||||
|
unknown = Nd4j.create(DataType.FLOAT, f.size(1));
|
||||||
|
|
||||||
|
assertEquals(unknown, f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(0)));
|
||||||
|
assertEquals(unknown, f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(1)));
|
||||||
|
assertEquals(unknown.like(), f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(3)));
|
||||||
|
|
||||||
|
assertEquals(unknown, f.get(NDArrayIndex.point(3), NDArrayIndex.all(), NDArrayIndex.point(0)));
|
||||||
|
assertEquals(unknown.like(), f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(1)));
|
||||||
|
|
||||||
|
//Sanity check on single sentence loading:
|
||||||
|
INDArray allKnownWords = dsi.loadSingleSentence("these balance");
|
||||||
|
INDArray withUnknown = dsi.loadSingleSentence("these NOVALID");
|
||||||
|
INDArray allUnknown = dsi.loadSingleSentence("NOVALID AlsoNotInVocab");
|
||||||
|
assertNotNull(allKnownWords);
|
||||||
|
assertNotNull(withUnknown);
|
||||||
|
assertNotNull(allUnknown);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue