Various fixes (#141)
* #8121 CnnSentenceDataSetIterator fixes
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8120 CnnSentenceDataSetIterator.loadSingleSentence no words UX/exception improvement
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8122 AggregatingSentenceIterator builder - addSentencePreProcessor -> sentencePreProcessor
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8082 Arbiter - fix GridSearchCandidateGenerator search size issue
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
parent
0adce9a4fa
commit
348d9c59f7
|
@ -19,6 +19,7 @@ package org.deeplearning4j.arbiter.optimize.api.adapter;
|
|||
import lombok.AllArgsConstructor;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -37,6 +38,8 @@ public abstract class ParameterSpaceAdapter<F, T> implements ParameterSpace<T> {
|
|||
|
||||
protected abstract ParameterSpace<F> underlying();
|
||||
|
||||
protected abstract String underlyingName();
|
||||
|
||||
|
||||
@Override
|
||||
public T getValue(double[] parameterValues) {
|
||||
|
@ -50,17 +53,21 @@ public abstract class ParameterSpaceAdapter<F, T> implements ParameterSpace<T> {
|
|||
|
||||
@Override
|
||||
public List<ParameterSpace> collectLeaves() {
|
||||
ParameterSpace p = underlying();
|
||||
if(p.isLeaf()){
|
||||
return Collections.singletonList(p);
|
||||
}
|
||||
return underlying().collectLeaves();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
||||
return underlying().getNestedSpaces();
|
||||
return Collections.singletonMap(underlyingName(), (ParameterSpace)underlying());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLeaf() {
|
||||
return underlying().isLeaf();
|
||||
return false; //Underlying may be a leaf, however
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -17,10 +17,12 @@
|
|||
package org.deeplearning4j.arbiter.optimize.generator;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.math3.random.RandomAdaptor;
|
||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
|
||||
import org.deeplearning4j.arbiter.util.LeafUtils;
|
||||
|
@ -65,6 +67,7 @@ public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
|
|||
private final Mode mode;
|
||||
|
||||
private int[] numValuesPerParam;
|
||||
@Getter
|
||||
private int totalNumCandidates;
|
||||
private Queue<Integer> order;
|
||||
|
||||
|
@ -123,6 +126,8 @@ public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
|
|||
int max = ips.getMax();
|
||||
//Discretize, as some integer ranges are much too large to search (e.g., num. neural network units, between 100 and 1000)
|
||||
numValuesPerParam[i] = Math.min(max - min + 1, discretizationCount);
|
||||
} else if (ps instanceof FixedValue){
|
||||
numValuesPerParam[i] = 1;
|
||||
} else {
|
||||
numValuesPerParam[i] = discretizationCount;
|
||||
}
|
||||
|
|
|
@ -44,10 +44,7 @@ import org.nd4j.linalg.learning.config.IUpdater;
|
|||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* This is an abstract ParameterSpace for both MultiLayerNetworks (MultiLayerSpace) and ComputationGraph (ComputationGraphSpace)
|
||||
|
@ -212,18 +209,30 @@ public abstract class BaseNetworkSpace<T> extends AbstractParameterSpace<T> {
|
|||
@Override
|
||||
public List<ParameterSpace> collectLeaves() {
|
||||
Map<String, ParameterSpace> global = getNestedSpaces();
|
||||
List<ParameterSpace> list = new ArrayList<>();
|
||||
list.addAll(global.values());
|
||||
|
||||
//Note: Results on previous line does NOT include the LayerSpaces, therefore we need to add these manually...
|
||||
//This is because the type is a list, not a ParameterSpace
|
||||
LinkedList<ParameterSpace> stack = new LinkedList<>();
|
||||
stack.add(this);
|
||||
|
||||
for (LayerConf layerConf : layerSpaces) {
|
||||
LayerSpace ls = layerConf.getLayerSpace();
|
||||
list.addAll(ls.collectLeaves());
|
||||
stack.addAll(ls.collectLeaves());
|
||||
}
|
||||
|
||||
return list;
|
||||
List<ParameterSpace> out = new ArrayList<>();
|
||||
while (!stack.isEmpty()) {
|
||||
ParameterSpace next = stack.removeLast();
|
||||
if (next.isLeaf()) {
|
||||
out.add(next);
|
||||
} else {
|
||||
Map<String, ParameterSpace> m = next.getNestedSpaces();
|
||||
ParameterSpace[] arr = m.values().toArray(new ParameterSpace[m.size()]);
|
||||
for (int i = arr.length - 1; i >= 0; i--) {
|
||||
stack.add(arr[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -84,8 +84,10 @@ public class MultiLayerSpace extends BaseNetworkSpace<DL4JConfiguration> {
|
|||
List<ParameterSpace> allLeaves = collectLeaves();
|
||||
List<ParameterSpace> list = LeafUtils.getUniqueObjects(allLeaves);
|
||||
|
||||
for (ParameterSpace ps : list)
|
||||
for (ParameterSpace ps : list) {
|
||||
int n = ps.numParameters();
|
||||
numParameters += ps.numParameters();
|
||||
}
|
||||
|
||||
this.trainingWorkspaceMode = builder.trainingWorkspaceMode;
|
||||
this.inferenceWorkspaceMode = builder.inferenceWorkspaceMode;
|
||||
|
|
|
@ -50,4 +50,9 @@ public class ActivationParameterSpaceAdapter extends ParameterSpaceAdapter<Activ
|
|||
protected ParameterSpace<Activation> underlying() {
|
||||
return activation;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String underlyingName() {
|
||||
return "activation";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,4 +52,9 @@ public class LossFunctionParameterSpaceAdapter
|
|||
protected ParameterSpace<LossFunctions.LossFunction> underlying() {
|
||||
return lossFunction;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String underlyingName() {
|
||||
return "lossFunction";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,9 +52,11 @@ public class DropoutLayerSpace extends LayerSpace<DropoutLayer> {
|
|||
|
||||
@Override
|
||||
public List<ParameterSpace> collectLeaves() {
|
||||
return Collections.<ParameterSpace>singletonList(dropOut);
|
||||
return dropOut.collectLeaves();
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public boolean isLeaf() {
|
||||
return false;
|
||||
|
|
|
@ -31,6 +31,7 @@ import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
|||
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
||||
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
|
||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
|
||||
|
@ -706,4 +707,63 @@ public class TestMultiLayerSpace {
|
|||
|
||||
MultiLayerConfiguration conf = mls.getValue(new double[nParams]).getMultiLayerConfiguration();
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testIssue8082(){
|
||||
ParameterSpace<Double> learningRateHyperparam = new DiscreteParameterSpace<>(0.003, 0.005, 0.01, 0.05);
|
||||
ParameterSpace<Integer> layerSizeHyperparam1 = new DiscreteParameterSpace<>(32, 64, 96, 128);
|
||||
ParameterSpace<Integer> layerSizeHyperparam2 = new DiscreteParameterSpace<>(32, 64, 96, 128);
|
||||
ParameterSpace<Double> dropoutHyperparam = new DiscreteParameterSpace<>(0.8, 0.9);
|
||||
|
||||
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
||||
.updater(new AdamSpace(learningRateHyperparam))
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.l2(0.0001)
|
||||
.addLayer(new DenseLayerSpace.Builder()
|
||||
.nIn(10)
|
||||
.nOut(layerSizeHyperparam1)
|
||||
.build())
|
||||
.addLayer(new BatchNormalizationSpace.Builder()
|
||||
.nOut(layerSizeHyperparam1)
|
||||
.activation(Activation.RELU)
|
||||
.build())
|
||||
.addLayer(new DropoutLayerSpace.Builder()
|
||||
.dropOut(dropoutHyperparam)
|
||||
.build())
|
||||
.addLayer(new DenseLayerSpace.Builder()
|
||||
.nOut(layerSizeHyperparam2)
|
||||
.build())
|
||||
.addLayer(new BatchNormalizationSpace.Builder()
|
||||
.nOut(layerSizeHyperparam2)
|
||||
.activation(Activation.RELU)
|
||||
.build())
|
||||
.addLayer(new DropoutLayerSpace.Builder()
|
||||
.dropOut(dropoutHyperparam)
|
||||
.build())
|
||||
.addLayer(new OutputLayerSpace.Builder()
|
||||
.nOut(10)
|
||||
.activation(Activation.SOFTMAX)
|
||||
.lossFunction(LossFunction.MCXENT)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
assertEquals(4, mls.getNumParameters());
|
||||
|
||||
for( int discreteCount : new int[]{1, 5}) {
|
||||
GridSearchCandidateGenerator generator = new GridSearchCandidateGenerator(mls, discreteCount, GridSearchCandidateGenerator.Mode.Sequential, null);
|
||||
|
||||
int expCandidates = 4 * 4 * 4 * 2;
|
||||
assertEquals(expCandidates, generator.getTotalNumCandidates());
|
||||
|
||||
int count = 0;
|
||||
while (generator.hasMoreCandidates()) {
|
||||
generator.getCandidate();
|
||||
count++;
|
||||
}
|
||||
|
||||
|
||||
assertEquals(expCandidates, count);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -117,18 +117,22 @@ public class CnnSentenceDataSetIterator implements DataSetIterator {
|
|||
List<String> sortedLabels = new ArrayList<>(this.sentenceProvider.allLabels());
|
||||
Collections.sort(sortedLabels);
|
||||
|
||||
this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
|
||||
|
||||
for (String s : sortedLabels) {
|
||||
this.labelClassMap.put(s, count++);
|
||||
}
|
||||
if (unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
|
||||
if (useNormalizedWordVectors) {
|
||||
wordVectors.getWordVectorMatrixNormalized(wordVectors.getUNK());
|
||||
unknown = wordVectors.getWordVectorMatrixNormalized(wordVectors.getUNK());
|
||||
} else {
|
||||
wordVectors.getWordVectorMatrix(wordVectors.getUNK());
|
||||
unknown = wordVectors.getWordVectorMatrix(wordVectors.getUNK());
|
||||
}
|
||||
|
||||
if(unknown == null){
|
||||
unknown = wordVectors.getWordVectorMatrix(wordVectors.vocab().wordAtIndex(0)).like();
|
||||
}
|
||||
}
|
||||
|
||||
this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -136,6 +140,9 @@ public class CnnSentenceDataSetIterator implements DataSetIterator {
|
|||
*/
|
||||
public INDArray loadSingleSentence(String sentence) {
|
||||
List<String> tokens = tokenizeSentence(sentence);
|
||||
if(tokens.isEmpty())
|
||||
throw new IllegalStateException("No tokens available for input sentence - empty string or no words in vocabulary with RemoveWord unknown handling? Sentence = \"" +
|
||||
sentence + "\"");
|
||||
if(format == Format.CNN1D || format == Format.RNN){
|
||||
int[] featuresShape = new int[] {1, wordVectorSize, Math.min(maxSentenceLength, tokens.size())};
|
||||
INDArray features = Nd4j.create(featuresShape, (format == Format.CNN1D ? 'c' : 'f'));
|
||||
|
|
|
@ -100,7 +100,15 @@ public class AggregatingSentenceIterator implements SentenceIterator {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #sentencePreProcessor(SentencePreProcessor)}
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder addSentencePreProcessor(@NonNull SentencePreProcessor preProcessor) {
|
||||
return sentencePreProcessor(preProcessor);
|
||||
}
|
||||
|
||||
public Builder sentencePreProcessor(@NonNull SentencePreProcessor preProcessor) {
|
||||
this.preProcessor = preProcessor;
|
||||
return this;
|
||||
}
|
||||
|
|
|
@ -264,6 +264,21 @@ public class TestCnnSentenceDataSetIterator extends BaseDL4JTest {
|
|||
assertEquals(expLabels, ds.getLabels());
|
||||
assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
|
||||
assertNull(ds.getLabelsMaskArray());
|
||||
|
||||
|
||||
//Sanity check on single sentence loading:
|
||||
INDArray allKnownWords = dsi.loadSingleSentence("these balance");
|
||||
INDArray withUnknown = dsi.loadSingleSentence("these NOVALID");
|
||||
assertNotNull(allKnownWords);
|
||||
assertNotNull(withUnknown);
|
||||
|
||||
try {
|
||||
dsi.loadSingleSentence("NOVALID AlsoNotInVocab");
|
||||
fail("Expected exception");
|
||||
} catch (Throwable t){
|
||||
String m = t.getMessage();
|
||||
assertTrue(m, m.contains("RemoveWord") && m.contains("vocabulary"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -324,4 +339,56 @@ public class TestCnnSentenceDataSetIterator extends BaseDL4JTest {
|
|||
assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
|
||||
assertNull(ds.getLabelsMaskArray());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testCnnSentenceDataSetIteratorUseUnknownVector() throws Exception {
|
||||
|
||||
WordVectors w2v = WordVectorSerializer
|
||||
.readWord2VecModel(new ClassPathResource("word2vec/googleload/sample_vec.bin").getFile());
|
||||
|
||||
List<String> sentences = new ArrayList<>();
|
||||
sentences.add("these balance Database model");
|
||||
sentences.add("into same THISWORDDOESNTEXIST are");
|
||||
//Last 2 sentences - no valid words
|
||||
sentences.add("NOVALID WORDSHERE");
|
||||
sentences.add("!!!");
|
||||
|
||||
List<String> labelsForSentences = Arrays.asList("Positive", "Negative", "Positive", "Negative");
|
||||
|
||||
|
||||
LabeledSentenceProvider p = new CollectionLabeledSentenceProvider(sentences, labelsForSentences, null);
|
||||
CnnSentenceDataSetIterator dsi = new CnnSentenceDataSetIterator.Builder(CnnSentenceDataSetIterator.Format.CNN1D)
|
||||
.unknownWordHandling(CnnSentenceDataSetIterator.UnknownWordHandling.UseUnknownVector)
|
||||
.sentenceProvider(p).wordVectors(w2v)
|
||||
.useNormalizedWordVectors(true)
|
||||
.maxSentenceLength(256).minibatchSize(4).sentencesAlongHeight(false).build();
|
||||
|
||||
assertTrue(dsi.hasNext());
|
||||
DataSet ds = dsi.next();
|
||||
|
||||
assertFalse(dsi.hasNext());
|
||||
|
||||
INDArray f = ds.getFeatures();
|
||||
assertEquals(4, f.size(0));
|
||||
|
||||
INDArray unknown = w2v.getWordVectorMatrix(w2v.getUNK());
|
||||
if(unknown == null)
|
||||
unknown = Nd4j.create(DataType.FLOAT, f.size(1));
|
||||
|
||||
assertEquals(unknown, f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(0)));
|
||||
assertEquals(unknown, f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(1)));
|
||||
assertEquals(unknown.like(), f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(3)));
|
||||
|
||||
assertEquals(unknown, f.get(NDArrayIndex.point(3), NDArrayIndex.all(), NDArrayIndex.point(0)));
|
||||
assertEquals(unknown.like(), f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(1)));
|
||||
|
||||
//Sanity check on single sentence loading:
|
||||
INDArray allKnownWords = dsi.loadSingleSentence("these balance");
|
||||
INDArray withUnknown = dsi.loadSingleSentence("these NOVALID");
|
||||
INDArray allUnknown = dsi.loadSingleSentence("NOVALID AlsoNotInVocab");
|
||||
assertNotNull(allKnownWords);
|
||||
assertNotNull(withUnknown);
|
||||
assertNotNull(allUnknown);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue