Various fixes (#141)

* #8121 CnnSentenceDataSetIterator fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8120 CnnSentenceDataSetIterator.loadSingleSentence no words UX/exception improvement

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8122 AggregatingSentenceIterator builder - addSentencePreProcessor -> sentencePreProcessor

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8082 Arbiter - fix GridSearchCandidateGenerator search size issue

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-21 23:47:24 +10:00 committed by GitHub
parent 0adce9a4fa
commit 348d9c59f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 194 additions and 17 deletions

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.arbiter.optimize.api.adapter;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -37,6 +38,8 @@ public abstract class ParameterSpaceAdapter<F, T> implements ParameterSpace<T> {
protected abstract ParameterSpace<F> underlying(); protected abstract ParameterSpace<F> underlying();
protected abstract String underlyingName();
@Override @Override
public T getValue(double[] parameterValues) { public T getValue(double[] parameterValues) {
@ -50,17 +53,21 @@ public abstract class ParameterSpaceAdapter<F, T> implements ParameterSpace<T> {
@Override @Override
public List<ParameterSpace> collectLeaves() { public List<ParameterSpace> collectLeaves() {
ParameterSpace p = underlying();
if(p.isLeaf()){
return Collections.singletonList(p);
}
return underlying().collectLeaves(); return underlying().collectLeaves();
} }
@Override @Override
public Map<String, ParameterSpace> getNestedSpaces() { public Map<String, ParameterSpace> getNestedSpaces() {
return underlying().getNestedSpaces(); return Collections.singletonMap(underlyingName(), (ParameterSpace)underlying());
} }
@Override @Override
public boolean isLeaf() { public boolean isLeaf() {
return underlying().isLeaf(); return false; //Underlying may be a leaf, however
} }
@Override @Override

View File

@ -17,10 +17,12 @@
package org.deeplearning4j.arbiter.optimize.generator; package org.deeplearning4j.arbiter.optimize.generator;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.math3.random.RandomAdaptor; import org.apache.commons.math3.random.RandomAdaptor;
import org.deeplearning4j.arbiter.optimize.api.Candidate; import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.deeplearning4j.arbiter.util.LeafUtils; import org.deeplearning4j.arbiter.util.LeafUtils;
@ -65,6 +67,7 @@ public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
private final Mode mode; private final Mode mode;
private int[] numValuesPerParam; private int[] numValuesPerParam;
@Getter
private int totalNumCandidates; private int totalNumCandidates;
private Queue<Integer> order; private Queue<Integer> order;
@ -123,6 +126,8 @@ public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
int max = ips.getMax(); int max = ips.getMax();
//Discretize, as some integer ranges are much too large to search (i.e., num. neural network units, between 100 and 1000) //Discretize, as some integer ranges are much too large to search (i.e., num. neural network units, between 100 and 1000)
numValuesPerParam[i] = Math.min(max - min + 1, discretizationCount); numValuesPerParam[i] = Math.min(max - min + 1, discretizationCount);
} else if (ps instanceof FixedValue){
numValuesPerParam[i] = 1;
} else { } else {
numValuesPerParam[i] = discretizationCount; numValuesPerParam[i] = discretizationCount;
} }

View File

@ -44,10 +44,7 @@ import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.core.JsonProcessingException;
import java.util.ArrayList; import java.util.*;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
/** /**
* This is an abstract ParameterSpace for both MultiLayerNetworks (MultiLayerSpace) and ComputationGraph (ComputationGraphSpace) * This is an abstract ParameterSpace for both MultiLayerNetworks (MultiLayerSpace) and ComputationGraph (ComputationGraphSpace)
@ -212,18 +209,30 @@ public abstract class BaseNetworkSpace<T> extends AbstractParameterSpace<T> {
@Override @Override
public List<ParameterSpace> collectLeaves() { public List<ParameterSpace> collectLeaves() {
Map<String, ParameterSpace> global = getNestedSpaces(); Map<String, ParameterSpace> global = getNestedSpaces();
List<ParameterSpace> list = new ArrayList<>();
list.addAll(global.values());
//Note: Results on previous line does NOT include the LayerSpaces, therefore we need to add these manually... //Note: Results on previous line does NOT include the LayerSpaces, therefore we need to add these manually...
//This is because the type is a list, not a ParameterSpace //This is because the type is a list, not a ParameterSpace
LinkedList<ParameterSpace> stack = new LinkedList<>();
stack.add(this);
for (LayerConf layerConf : layerSpaces) { for (LayerConf layerConf : layerSpaces) {
LayerSpace ls = layerConf.getLayerSpace(); LayerSpace ls = layerConf.getLayerSpace();
list.addAll(ls.collectLeaves()); stack.addAll(ls.collectLeaves());
} }
return list; List<ParameterSpace> out = new ArrayList<>();
while (!stack.isEmpty()) {
ParameterSpace next = stack.removeLast();
if (next.isLeaf()) {
out.add(next);
} else {
Map<String, ParameterSpace> m = next.getNestedSpaces();
ParameterSpace[] arr = m.values().toArray(new ParameterSpace[m.size()]);
for (int i = arr.length - 1; i >= 0; i--) {
stack.add(arr[i]);
}
}
}
return out;
} }

View File

@ -84,8 +84,10 @@ public class MultiLayerSpace extends BaseNetworkSpace<DL4JConfiguration> {
List<ParameterSpace> allLeaves = collectLeaves(); List<ParameterSpace> allLeaves = collectLeaves();
List<ParameterSpace> list = LeafUtils.getUniqueObjects(allLeaves); List<ParameterSpace> list = LeafUtils.getUniqueObjects(allLeaves);
for (ParameterSpace ps : list) for (ParameterSpace ps : list) {
int n = ps.numParameters();
numParameters += ps.numParameters(); numParameters += ps.numParameters();
}
this.trainingWorkspaceMode = builder.trainingWorkspaceMode; this.trainingWorkspaceMode = builder.trainingWorkspaceMode;
this.inferenceWorkspaceMode = builder.inferenceWorkspaceMode; this.inferenceWorkspaceMode = builder.inferenceWorkspaceMode;

View File

@ -50,4 +50,9 @@ public class ActivationParameterSpaceAdapter extends ParameterSpaceAdapter<Activ
protected ParameterSpace<Activation> underlying() { protected ParameterSpace<Activation> underlying() {
return activation; return activation;
} }
@Override
// Key under which the wrapped activation-function space is exposed; the adapter base
// class maps underlyingName() -> underlying() in getNestedSpaces() (see the
// ParameterSpaceAdapter hunk above: Collections.singletonMap(underlyingName(), underlying())).
protected String underlyingName() {
return "activation";
}
} }

View File

@ -52,4 +52,9 @@ public class LossFunctionParameterSpaceAdapter
protected ParameterSpace<LossFunctions.LossFunction> underlying() { protected ParameterSpace<LossFunctions.LossFunction> underlying() {
return lossFunction; return lossFunction;
} }
@Override
// Key under which the wrapped loss-function space is exposed; consumed by
// ParameterSpaceAdapter.getNestedSpaces(), which returns a singleton map of
// underlyingName() -> underlying().
protected String underlyingName() {
return "lossFunction";
}
} }

View File

@ -52,9 +52,11 @@ public class DropoutLayerSpace extends LayerSpace<DropoutLayer> {
@Override @Override
public List<ParameterSpace> collectLeaves() { public List<ParameterSpace> collectLeaves() {
return Collections.<ParameterSpace>singletonList(dropOut); return dropOut.collectLeaves();
} }
@Override @Override
public boolean isLeaf() { public boolean isLeaf() {
return false; return false;

View File

@ -31,6 +31,7 @@ import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition; import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
@ -706,4 +707,63 @@ public class TestMultiLayerSpace {
MultiLayerConfiguration conf = mls.getValue(new double[nParams]).getMultiLayerConfiguration(); MultiLayerConfiguration conf = mls.getValue(new double[nParams]).getMultiLayerConfiguration();
} }
@Test
// Regression test for issue #8082: GridSearchCandidateGenerator reported/produced the
// wrong total search-space size. The space below has exactly four unique discrete
// hyperparameters (4 learning rates, 4 sizes for layer pair 1, 4 sizes for layer
// pair 2, 2 dropout values); layerSizeHyperparam1/2 and dropoutHyperparam are each
// reused across two layers but must count only once (deduplicated via
// LeafUtils.getUniqueObjects — see the MultiLayerSpace hunk above).
public void testIssue8082(){
ParameterSpace<Double> learningRateHyperparam = new DiscreteParameterSpace<>(0.003, 0.005, 0.01, 0.05);
ParameterSpace<Integer> layerSizeHyperparam1 = new DiscreteParameterSpace<>(32, 64, 96, 128);
ParameterSpace<Integer> layerSizeHyperparam2 = new DiscreteParameterSpace<>(32, 64, 96, 128);
ParameterSpace<Double> dropoutHyperparam = new DiscreteParameterSpace<>(0.8, 0.9);
MultiLayerSpace mls = new MultiLayerSpace.Builder()
.updater(new AdamSpace(learningRateHyperparam))
.weightInit(WeightInit.XAVIER)
.l2(0.0001)
.addLayer(new DenseLayerSpace.Builder()
.nIn(10)
.nOut(layerSizeHyperparam1)
.build())
.addLayer(new BatchNormalizationSpace.Builder()
.nOut(layerSizeHyperparam1)
.activation(Activation.RELU)
.build())
.addLayer(new DropoutLayerSpace.Builder()
.dropOut(dropoutHyperparam)
.build())
.addLayer(new DenseLayerSpace.Builder()
.nOut(layerSizeHyperparam2)
.build())
.addLayer(new BatchNormalizationSpace.Builder()
.nOut(layerSizeHyperparam2)
.activation(Activation.RELU)
.build())
.addLayer(new DropoutLayerSpace.Builder()
.dropOut(dropoutHyperparam)
.build())
.addLayer(new OutputLayerSpace.Builder()
.nOut(10)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunction.MCXENT)
.build())
.build();
// Exactly the 4 unique discrete spaces declared above — reused spaces are not double-counted
assertEquals(4, mls.getNumParameters());
// The grid size must be the product of the discrete value counts (4*4*4*2 = 128),
// regardless of the discretizationCount setting (which only applies to
// continuous/integer-range spaces — none are present here), hence both 1 and 5.
for( int discreteCount : new int[]{1, 5}) {
GridSearchCandidateGenerator generator = new GridSearchCandidateGenerator(mls, discreteCount, GridSearchCandidateGenerator.Mode.Sequential, null);
int expCandidates = 4 * 4 * 4 * 2;
assertEquals(expCandidates, generator.getTotalNumCandidates());
// Exhaust the generator to confirm it actually yields exactly that many candidates
int count = 0;
while (generator.hasMoreCandidates()) {
generator.getCandidate();
count++;
}
assertEquals(expCandidates, count);
}
}
} }

View File

@ -117,18 +117,22 @@ public class CnnSentenceDataSetIterator implements DataSetIterator {
List<String> sortedLabels = new ArrayList<>(this.sentenceProvider.allLabels()); List<String> sortedLabels = new ArrayList<>(this.sentenceProvider.allLabels());
Collections.sort(sortedLabels); Collections.sort(sortedLabels);
this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
for (String s : sortedLabels) { for (String s : sortedLabels) {
this.labelClassMap.put(s, count++); this.labelClassMap.put(s, count++);
} }
if (unknownWordHandling == UnknownWordHandling.UseUnknownVector) { if (unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
if (useNormalizedWordVectors) { if (useNormalizedWordVectors) {
wordVectors.getWordVectorMatrixNormalized(wordVectors.getUNK()); unknown = wordVectors.getWordVectorMatrixNormalized(wordVectors.getUNK());
} else { } else {
wordVectors.getWordVectorMatrix(wordVectors.getUNK()); unknown = wordVectors.getWordVectorMatrix(wordVectors.getUNK());
}
} }
this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length; if(unknown == null){
unknown = wordVectors.getWordVectorMatrix(wordVectors.vocab().wordAtIndex(0)).like();
}
}
} }
/** /**
@ -136,6 +140,9 @@ public class CnnSentenceDataSetIterator implements DataSetIterator {
*/ */
public INDArray loadSingleSentence(String sentence) { public INDArray loadSingleSentence(String sentence) {
List<String> tokens = tokenizeSentence(sentence); List<String> tokens = tokenizeSentence(sentence);
if(tokens.isEmpty())
throw new IllegalStateException("No tokens available for input sentence - empty string or no words in vocabulary with RemoveWord unknown handling? Sentence = \"" +
sentence + "\"");
if(format == Format.CNN1D || format == Format.RNN){ if(format == Format.CNN1D || format == Format.RNN){
int[] featuresShape = new int[] {1, wordVectorSize, Math.min(maxSentenceLength, tokens.size())}; int[] featuresShape = new int[] {1, wordVectorSize, Math.min(maxSentenceLength, tokens.size())};
INDArray features = Nd4j.create(featuresShape, (format == Format.CNN1D ? 'c' : 'f')); INDArray features = Nd4j.create(featuresShape, (format == Format.CNN1D ? 'c' : 'f'));

View File

@ -100,7 +100,15 @@ public class AggregatingSentenceIterator implements SentenceIterator {
return this; return this;
} }
/**
* @deprecated Use {@link #sentencePreProcessor(SentencePreProcessor)}
*/
@Deprecated
public Builder addSentencePreProcessor(@NonNull SentencePreProcessor preProcessor) { public Builder addSentencePreProcessor(@NonNull SentencePreProcessor preProcessor) {
return sentencePreProcessor(preProcessor);
}
public Builder sentencePreProcessor(@NonNull SentencePreProcessor preProcessor) {
this.preProcessor = preProcessor; this.preProcessor = preProcessor;
return this; return this;
} }

View File

@ -264,6 +264,21 @@ public class TestCnnSentenceDataSetIterator extends BaseDL4JTest {
assertEquals(expLabels, ds.getLabels()); assertEquals(expLabels, ds.getLabels());
assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray()); assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
assertNull(ds.getLabelsMaskArray()); assertNull(ds.getLabelsMaskArray());
//Sanity check on single sentence loading:
INDArray allKnownWords = dsi.loadSingleSentence("these balance");
INDArray withUnknown = dsi.loadSingleSentence("these NOVALID");
assertNotNull(allKnownWords);
assertNotNull(withUnknown);
try {
dsi.loadSingleSentence("NOVALID AlsoNotInVocab");
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage();
assertTrue(m, m.contains("RemoveWord") && m.contains("vocabulary"));
}
} }
@Test @Test
@ -324,4 +339,56 @@ public class TestCnnSentenceDataSetIterator extends BaseDL4JTest {
assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray()); assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
assertNull(ds.getLabelsMaskArray()); assertNull(ds.getLabelsMaskArray());
} }
@Test
// Verifies UnknownWordHandling.UseUnknownVector: out-of-vocabulary words map to the
// model's UNK vector (or, per the iterator fix in the hunk above, a same-shape zero
// vector when the model has no UNK), and sentences with NO valid words still load
// rather than throwing (contrast with the RemoveWord test above).
public void testCnnSentenceDataSetIteratorUseUnknownVector() throws Exception {
WordVectors w2v = WordVectorSerializer
.readWord2VecModel(new ClassPathResource("word2vec/googleload/sample_vec.bin").getFile());
List<String> sentences = new ArrayList<>();
sentences.add("these balance Database model");
sentences.add("into same THISWORDDOESNTEXIST are");
//Last 2 sentences - no valid words
sentences.add("NOVALID WORDSHERE");
sentences.add("!!!");
List<String> labelsForSentences = Arrays.asList("Positive", "Negative", "Positive", "Negative");
LabeledSentenceProvider p = new CollectionLabeledSentenceProvider(sentences, labelsForSentences, null);
CnnSentenceDataSetIterator dsi = new CnnSentenceDataSetIterator.Builder(CnnSentenceDataSetIterator.Format.CNN1D)
.unknownWordHandling(CnnSentenceDataSetIterator.UnknownWordHandling.UseUnknownVector)
.sentenceProvider(p).wordVectors(w2v)
.useNormalizedWordVectors(true)
.maxSentenceLength(256).minibatchSize(4).sentencesAlongHeight(false).build();
assertTrue(dsi.hasNext());
DataSet ds = dsi.next();
assertFalse(dsi.hasNext());
INDArray f = ds.getFeatures();
// All 4 sentences must appear in the minibatch, including the two with no known words
assertEquals(4, f.size(0));
// Expected OOV vector: the model's UNK, falling back to zeros of matching length
// when no UNK vector exists (mirrors the iterator's own fallback)
INDArray unknown = w2v.getWordVectorMatrix(w2v.getUNK());
if(unknown == null)
unknown = Nd4j.create(DataType.FLOAT, f.size(1));
// Sentence 2 ("NOVALID WORDSHERE"): both word positions get the unknown vector...
assertEquals(unknown, f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(0)));
assertEquals(unknown, f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(1)));
// ...and positions beyond its 2 tokens are zero padding (unknown.like() == zeros)
assertEquals(unknown.like(), f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(3)));
assertEquals(unknown, f.get(NDArrayIndex.point(3), NDArrayIndex.all(), NDArrayIndex.point(0)));
// NOTE(review): this re-asserts slice [2, :, 1] — asserted equal to `unknown` three
// lines above — against unknown.like(); both only hold when unknown is all-zero
// (no UNK in the model). Presumably point(3) was intended here — TODO confirm.
assertEquals(unknown.like(), f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(1)));
//Sanity check on single sentence loading:
INDArray allKnownWords = dsi.loadSingleSentence("these balance");
INDArray withUnknown = dsi.loadSingleSentence("these NOVALID");
// With UseUnknownVector even an all-OOV sentence must load (no exception), unlike RemoveWord
INDArray allUnknown = dsi.loadSingleSentence("NOVALID AlsoNotInVocab");
assertNotNull(allKnownWords);
assertNotNull(withUnknown);
assertNotNull(allUnknown);
}
} }