BertIterator sentence pair support for supervised training (#108)

* bert iterator sentence pair handling

Signed-off-by: eraly <susan.eraly@gmail.com>

* bert iterator sentence pair handling -seg

Signed-off-by: eraly <susan.eraly@gmail.com>

* bert iterator sentence pair handling tests

Signed-off-by: eraly <susan.eraly@gmail.com>

* test with pairs long done

Signed-off-by: eraly <susan.eraly@gmail.com>

* more tests with bert iter sent pairs done

Signed-off-by: eraly <susan.eraly@gmail.com>

* fixed copyright, formatting

Signed-off-by: eraly <susan.eraly@gmail.com>

* bert iterator - added featurizer for sentence pair inference

Signed-off-by: eraly <susan.eraly@gmail.com>

* bert iterator - finished tests

Signed-off-by: eraly <susan.eraly@gmail.com>

* bert iterator - finished tests, polish

Signed-off-by: eraly <susan.eraly@gmail.com>

* collection labeled sentence provider

Signed-off-by: eraly <susan.eraly@gmail.com>

* lombok fix for pojo class

Signed-off-by: eraly <susan.eraly@gmail.com>

* java doc misc clean up

Signed-off-by: eraly <susan.eraly@gmail.com>

* Private access modifiers

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-12-04 23:35:38 +11:00 committed by GitHub
parent 9cc8803b8d
commit 91de96588c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 768 additions and 72 deletions

View File

@ -34,6 +34,7 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -85,10 +86,20 @@ import java.util.Map;
* <pre> * <pre>
* {@code * {@code
* BertIterator b; * BertIterator b;
* Pair<INDArray[],INDArray[]> featuresAndMask;
* INDArray[] features;
* INDArray[] featureMasks;
*
* //With sentences
* List<String> forInference; * List<String> forInference;
* Pair<INDArray[],INDArray[]> featuresAndMask = b.featurizeSentences(forInference); * featuresAndMask = b.featurizeSentences(forInference);
* INDArray[] features = featuresAndMask.getFirst(); *
* INDArray[] featureMasks = featuresAndMask.getSecond(); * //OR with sentence pairs
* List<Pair<String, String>> forInferencePair};
* featuresAndMask = b.featurizeSentencePairs(forInference);
*
* features = featuresAndMask.getFirst();
* featureMasks = featuresAndMask.getSecond();
* } * }
* </pre> * </pre>
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br> * This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br>
@ -135,6 +146,7 @@ public class BertIterator implements MultiDataSetIterator {
@Setter @Setter
protected MultiDataSetPreProcessor preProcessor; protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null; protected LabeledSentenceProvider sentenceProvider = null;
protected LabeledPairSentenceProvider sentencePairProvider = null;
protected LengthHandling lengthHandling; protected LengthHandling lengthHandling;
protected FeatureArrays featureArrays; protected FeatureArrays featureArrays;
protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects? protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
@ -142,6 +154,7 @@ public class BertIterator implements MultiDataSetIterator {
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null; protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
protected String maskToken; protected String maskToken;
protected String prependToken; protected String prependToken;
protected String appendToken;
protected List<String> vocabKeysAsList; protected List<String> vocabKeysAsList;
@ -154,6 +167,7 @@ public class BertIterator implements MultiDataSetIterator {
this.padMinibatches = b.padMinibatches; this.padMinibatches = b.padMinibatches;
this.preProcessor = b.preProcessor; this.preProcessor = b.preProcessor;
this.sentenceProvider = b.sentenceProvider; this.sentenceProvider = b.sentenceProvider;
this.sentencePairProvider = b.sentencePairProvider;
this.lengthHandling = b.lengthHandling; this.lengthHandling = b.lengthHandling;
this.featureArrays = b.featureArrays; this.featureArrays = b.featureArrays;
this.vocabMap = b.vocabMap; this.vocabMap = b.vocabMap;
@ -161,11 +175,14 @@ public class BertIterator implements MultiDataSetIterator {
this.unsupervisedLabelFormat = b.unsupervisedLabelFormat; this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
this.maskToken = b.maskToken; this.maskToken = b.maskToken;
this.prependToken = b.prependToken; this.prependToken = b.prependToken;
this.appendToken = b.appendToken;
} }
@Override @Override
public boolean hasNext() { public boolean hasNext() {
return sentenceProvider.hasNext(); if (sentenceProvider != null)
return sentenceProvider.hasNext();
return sentencePairProvider.hasNext();
} }
@Override @Override
@ -181,29 +198,38 @@ public class BertIterator implements MultiDataSetIterator {
@Override @Override
public MultiDataSet next(int num) { public MultiDataSet next(int num) {
Preconditions.checkState(hasNext(), "No next element available"); Preconditions.checkState(hasNext(), "No next element available");
List<Pair<List<String>, String>> tokensAndLabelList;
List<Pair<String, String>> list = new ArrayList<>(num);
int mbSize = 0; int mbSize = 0;
int outLength;
long[] segIdOnesFrom = null;
if (sentenceProvider != null) { if (sentenceProvider != null) {
List<Pair<String, String>> list = new ArrayList<>(num);
while (sentenceProvider.hasNext() && mbSize++ < num) { while (sentenceProvider.hasNext() && mbSize++ < num) {
list.add(sentenceProvider.nextSentence()); list.add(sentenceProvider.nextSentence());
} }
SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(list);
tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
outLength = sentenceListProcessed.getMaxL();
} else if (sentencePairProvider != null) {
List<Triple<String, String, String>> listPairs = new ArrayList<>(num);
while (sentencePairProvider.hasNext() && mbSize++ < num) {
listPairs.add(sentencePairProvider.nextSentencePair());
}
SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(listPairs);
tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
outLength = sentencePairListProcessed.getMaxL();
segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
} else { } else {
//TODO - other types of iterators... //TODO - other types of iterators...
throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented"); throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
} }
Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list);
List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
int outLength = outLTokenizedSentencesPair.getLeft();
Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength);
INDArray[] featureArray = featuresAndMaskArraysPair.getFirst(); INDArray[] featureArray = featuresAndMaskArraysPair.getFirst();
INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond(); INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond();
Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength); Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokensAndLabelList, featureArray, outLength);
INDArray[] labelArray = labelsAndMaskArraysPair.getFirst(); INDArray[] labelArray = labelsAndMaskArraysPair.getFirst();
INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond(); INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond();
@ -224,32 +250,59 @@ public class BertIterator implements MultiDataSetIterator {
public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) { public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) {
List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences); List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(sentencesWithNullLabel);
List<Pair<List<String>, String>> tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
int outLength = sentenceListProcessed.getMaxL();
Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel);
List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
int outLength = outLTokenizedSentencesPair.getLeft();
Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength);
if (preProcessor != null) { if (preProcessor != null) {
Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null); MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null);
preProcessor.preProcess(dummyMDS); preProcessor.preProcess(dummyMDS);
return new Pair<INDArray[],INDArray[]>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays()); return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
} }
return convertMiniBatchFeatures(tokenizedSentences, outLength); return convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
} }
private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokenizedSentences, int outLength) { /**
int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size(); * For use during inference. Will convert a given pair of a list of sentences to features and feature masks as appropriate.
*
* @param listOnlySentencePairs
* @return Pair of INDArrays[], first element is feature arrays and the second is the masks array
*/
public Pair<INDArray[], INDArray[]> featurizeSentencePairs(List<Pair<String, String>> listOnlySentencePairs) {
Preconditions.checkState(sentencePairProvider != null, "The featurizeSentencePairs method is meant for inference with sentence pairs. Use only when the sentence pair provider is set (i.e not null).");
List<Triple<String, String, String>> sentencePairsWithNullLabel = addDummyLabelForPairs(listOnlySentencePairs);
SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(sentencePairsWithNullLabel);
List<Pair<List<String>, String>> tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
int outLength = sentencePairListProcessed.getMaxL();
long[] segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
if (preProcessor != null) {
Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featuresAndMaskArraysPair.getFirst(), null, featuresAndMaskArraysPair.getSecond(), null);
preProcessor.preProcess(dummyMDS);
return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
}
return convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
}
private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokensAndLabelList, int outLength, long[] segIdOnesFrom) {
int mbPadded = padMinibatches ? minibatchSize : tokensAndLabelList.size();
int[][] outIdxs = new int[mbPadded][outLength]; int[][] outIdxs = new int[mbPadded][outLength];
int[][] outMask = new int[mbPadded][outLength]; int[][] outMask = new int[mbPadded][outLength];
for (int i = 0; i < tokenizedSentences.size(); i++) { int[][] outSegmentId = null;
Pair<List<String>, String> p = tokenizedSentences.get(i); if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID)
outSegmentId = new int[mbPadded][outLength];
for (int i = 0; i < tokensAndLabelList.size(); i++) {
Pair<List<String>, String> p = tokensAndLabelList.get(i);
List<String> t = p.getFirst(); List<String> t = p.getFirst();
for (int j = 0; j < outLength && j < t.size(); j++) { for (int j = 0; j < outLength && j < t.size(); j++) {
Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j)); Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j));
int idx = vocabMap.get(t.get(j)); int idx = vocabMap.get(t.get(j));
outIdxs[i][j] = idx; outIdxs[i][j] = idx;
outMask[i][j] = 1; outMask[i][j] = 1;
if (segIdOnesFrom != null && j >= segIdOnesFrom[i])
outSegmentId[i][j] = 1;
} }
} }
@ -260,8 +313,7 @@ public class BertIterator implements MultiDataSetIterator {
INDArray[] f; INDArray[] f;
INDArray[] fm; INDArray[] fm;
if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) { if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) {
//For now: always segment index 0 (only single s sequence input supported) outSegmentIdArr = Nd4j.createFromArray(outSegmentId);
outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength);
f = new INDArray[]{outIdxsArr, outSegmentIdArr}; f = new INDArray[]{outIdxsArr, outSegmentIdArr};
fm = new INDArray[]{outMaskArr, null}; fm = new INDArray[]{outMaskArr, null};
} else { } else {
@ -271,16 +323,15 @@ public class BertIterator implements MultiDataSetIterator {
return new Pair<>(f, fm); return new Pair<>(f, fm);
} }
private Pair<Integer, List<Pair<List<String>, String>>> tokenizeMiniBatch(List<Pair<String, String>> list) { private SentenceListProcessed tokenizeMiniBatch(List<Pair<String, String>> list) {
//Get and tokenize the sentences for this minibatch //Get and tokenize the sentences for this minibatch
List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(list.size()); SentenceListProcessed sentenceListProcessed = new SentenceListProcessed(list.size());
int longestSeq = -1; int longestSeq = -1;
for (Pair<String, String> p : list) { for (Pair<String, String> p : list) {
List<String> tokens = tokenizeSentence(p.getFirst()); List<String> tokens = tokenizeSentence(p.getFirst());
tokenizedSentences.add(new Pair<>(tokens, p.getSecond())); sentenceListProcessed.addProcessedToList(new Pair<>(tokens, p.getSecond()));
longestSeq = Math.max(longestSeq, tokens.size()); longestSeq = Math.max(longestSeq, tokens.size());
} }
//Determine output array length... //Determine output array length...
int outLength; int outLength;
switch (lengthHandling) { switch (lengthHandling) {
@ -296,7 +347,52 @@ public class BertIterator implements MultiDataSetIterator {
default: default:
throw new RuntimeException("Not implemented length handling mode: " + lengthHandling); throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
} }
return new Pair<>(outLength, tokenizedSentences); sentenceListProcessed.setMaxL(outLength);
return sentenceListProcessed;
}
private SentencePairListProcessed tokenizePairsMiniBatch(List<Triple<String, String, String>> listPairs) {
SentencePairListProcessed sentencePairListProcessed = new SentencePairListProcessed(listPairs.size());
for (Triple<String, String, String> t : listPairs) {
List<String> tokensL = tokenizeSentence(t.getFirst(), true);
List<String> tokensR = tokenizeSentence(t.getSecond(), true);
List<String> tokens = new ArrayList<>(maxTokens);
int maxLength = maxTokens;
if (prependToken != null)
maxLength--;
if (appendToken != null)
maxLength -= 2;
if (tokensL.size() + tokensR.size() > maxLength) {
boolean shortOnL = tokensL.size() < tokensR.size();
int shortSize = Math.min(tokensL.size(), tokensR.size());
if (shortSize > maxLength / 2) {
//both lists need to be sliced
tokensL.subList(maxLength / 2, tokensL.size()).clear(); //if maxsize/2 is odd pop extra on L side to match implementation in TF
tokensR.subList(maxLength - maxLength / 2, tokensR.size()).clear();
} else {
//slice longer list
if (shortOnL) {
//longer on R - slice R
tokensR.subList(maxLength - tokensL.size(), tokensR.size()).clear();
} else {
//longer on L - slice L
tokensL.subList(maxLength - tokensR.size(), tokensL.size()).clear();
}
}
}
if (prependToken != null)
tokens.add(prependToken);
tokens.addAll(tokensL);
if (appendToken != null)
tokens.add(appendToken);
int segIdOnesFrom = tokens.size();
tokens.addAll(tokensR);
if (appendToken != null)
tokens.add(appendToken);
sentencePairListProcessed.addProcessedToList(segIdOnesFrom, new Pair<>(tokens, t.getThird()));
}
sentencePairListProcessed.setMaxL(maxTokens);
return sentencePairListProcessed;
} }
private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) { private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
@ -316,6 +412,14 @@ public class BertIterator implements MultiDataSetIterator {
classLabels[i] = labels.indexOf(lbl); classLabels[i] = labels.indexOf(lbl);
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl); Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
} }
} else if (sentencePairProvider != null) {
numClasses = sentencePairProvider.numLabelClasses();
List<String> labels = sentencePairProvider.allLabels();
for (int i = 0; i < mbSize; i++) {
String lbl = tokenizedSentences.get(i).getRight();
classLabels[i] = labels.indexOf(lbl);
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
}
} else { } else {
throw new RuntimeException(); throw new RuntimeException();
} }
@ -392,16 +496,22 @@ public class BertIterator implements MultiDataSetIterator {
} }
private List<String> tokenizeSentence(String sentence) { private List<String> tokenizeSentence(String sentence) {
return tokenizeSentence(sentence, false);
}
private List<String> tokenizeSentence(String sentence, boolean ignorePrependAppend) {
Tokenizer t = tokenizerFactory.create(sentence); Tokenizer t = tokenizerFactory.create(sentence);
List<String> tokens = new ArrayList<>(); List<String> tokens = new ArrayList<>();
if (prependToken != null) if (prependToken != null && !ignorePrependAppend)
tokens.add(prependToken); tokens.add(prependToken);
while (t.hasMoreTokens()) { while (t.hasMoreTokens()) {
String token = t.nextToken(); String token = t.nextToken();
tokens.add(token); tokens.add(token);
} }
if (appendToken != null && !ignorePrependAppend)
tokens.add(appendToken);
return tokens; return tokens;
} }
@ -414,6 +524,13 @@ public class BertIterator implements MultiDataSetIterator {
return list; return list;
} }
private List<Triple<String, String, String>> addDummyLabelForPairs(List<Pair<String, String>> listOnlySentencePairs) {
List<Triple<String, String, String>> list = new ArrayList<>(listOnlySentencePairs.size());
for (Pair<String, String> p : listOnlySentencePairs) {
list.add(new Triple<String, String, String>(p.getFirst(), p.getSecond(), null));
}
return list;
}
@Override @Override
public boolean resetSupported() { public boolean resetSupported() {
@ -446,12 +563,14 @@ public class BertIterator implements MultiDataSetIterator {
protected boolean padMinibatches = false; protected boolean padMinibatches = false;
protected MultiDataSetPreProcessor preProcessor; protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null; protected LabeledSentenceProvider sentenceProvider = null;
protected LabeledPairSentenceProvider sentencePairProvider = null;
protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID; protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects? protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
protected BertSequenceMasker masker = new BertMaskedLMMasker(); protected BertSequenceMasker masker = new BertMaskedLMMasker();
protected UnsupervisedLabelFormat unsupervisedLabelFormat; protected UnsupervisedLabelFormat unsupervisedLabelFormat;
protected String maskToken; protected String maskToken;
protected String prependToken; protected String prependToken;
protected String appendToken;
/** /**
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details. * Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
@ -519,14 +638,21 @@ public class BertIterator implements MultiDataSetIterator {
} }
/** /**
* Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised * Specify the source of the data for classification.
* use case, the labels will be ignored.
*/ */
public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) { public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) {
this.sentenceProvider = sentenceProvider; this.sentenceProvider = sentenceProvider;
return this; return this;
} }
/**
* Specify the source of the data for classification on sentence pairs.
*/
public Builder sentencePairProvider(LabeledPairSentenceProvider sentencePairProvider) {
this.sentencePairProvider = sentencePairProvider;
return this;
}
/** /**
* Specify what arrays should be returned. See {@link BertIterator} for more details. * Specify what arrays should be returned. See {@link BertIterator} for more details.
*/ */
@ -591,6 +717,19 @@ public class BertIterator implements MultiDataSetIterator {
return this; return this;
} }
/**
* Append the specified token to the sequences, when doing training on sentence pairs.<br>
* Generally "[SEP]" is used
* No token in appended by default.
*
* @param appendToken Token at end of each sentence for pairs of sentences (null: no token will be appended)
* @return
*/
public Builder appendToken(String appendToken) {
this.appendToken = appendToken;
return this;
}
public BertIterator build() { public BertIterator build() {
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed"); Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required"); Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
@ -598,9 +737,69 @@ public class BertIterator implements MultiDataSetIterator {
Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method"); Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method");
Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via masker(BertSequenceMasker) method"); Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via masker(BertSequenceMasker) method");
Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\" must be specified"); Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\" must be specified");
if (sentencePairProvider != null) {
Preconditions.checkState(task == Task.SEQ_CLASSIFICATION, "Currently only supervised sequence classification is set up with sentence pairs. \".task(BertIterator.Task.SEQ_CLASSIFICATION)\" is required with a sentence pair provider");
Preconditions.checkState(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID, "Currently only supervised sequence classification is set up with sentence pairs. \".featureArrays(FeatureArrays.INDICES_MASK_SEGMENTID)\" is required with a sentence pair provider");
Preconditions.checkState(lengthHandling == LengthHandling.FIXED_LENGTH, "Currently only fixed length is supported for sentence pairs. \".lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxLength)\" is required with a sentence pair provider");
Preconditions.checkState(sentencePairProvider != null, "Provide either a sentence provider or a sentence pair provider. Both cannot be non null");
}
if (appendToken != null) {
Preconditions.checkState(sentencePairProvider != null, "Tokens are only appended with sentence pairs. Sentence pair provider is not set. Set sentence pair provider.");
}
return new BertIterator(this); return new BertIterator(this);
} }
} }
private static class SentencePairListProcessed {
private int listLength = 0;
@Getter
private long[] segIdOnesFrom;
private int cursor = 0;
private SentenceListProcessed sentenceListProcessed;
private SentencePairListProcessed(int listLength) {
this.listLength = listLength;
segIdOnesFrom = new long[listLength];
sentenceListProcessed = new SentenceListProcessed(listLength);
}
private void addProcessedToList(long segIdIdx, Pair<List<String>, String> tokenizedSentencePairAndLabel) {
segIdOnesFrom[cursor] = segIdIdx;
sentenceListProcessed.addProcessedToList(tokenizedSentencePairAndLabel);
cursor++;
}
private void setMaxL(int maxL) {
sentenceListProcessed.setMaxL(maxL);
}
private int getMaxL() {
return sentenceListProcessed.getMaxL();
}
private List<Pair<List<String>, String>> getTokensAndLabelList() {
return sentenceListProcessed.getTokensAndLabelList();
}
}
private static class SentenceListProcessed {
private int listLength;
@Getter
@Setter
private int maxL;
@Getter
private List<Pair<List<String>, String>> tokensAndLabelList;
private SentenceListProcessed(int listLength) {
this.listLength = listLength;
tokensAndLabelList = new ArrayList<>(listLength);
}
private void addProcessedToList(Pair<List<String>, String> tokenizedSentenceAndLabel) {
tokensAndLabelList.add(tokenizedSentenceAndLabel);
}
}
} }

View File

@ -0,0 +1,60 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.iterator;
import org.nd4j.linalg.primitives.Triple;
import java.util.List;
/**
* LabeledPairSentenceProvider: a simple iterator interface over a pair of sentences/documents that have a label.<br>
*/
public interface LabeledPairSentenceProvider {
/**
* Are there more sentences/documents available?
*/
boolean hasNext();
/**
* @return Triple: two sentence/document texts and label
*/
Triple<String, String, String> nextSentencePair();
/**
* Reset the iterator - including shuffling the order, if necessary/appropriate
*/
void reset();
/**
* Return the total number of sentences, or -1 if not available
*/
int totalNumSentences();
/**
* Return the list of labels - this also defines the class/integer label assignment order
*/
List<String> allLabels();
/**
* Equivalent to allLabels().size()
*/
int numLabelClasses();
}

View File

@ -0,0 +1,135 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.iterator.provider;
import lombok.NonNull;
import org.deeplearning4j.iterator.LabeledPairSentenceProvider;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.linalg.util.MathUtils;
import java.util.*;
/**
* Iterate over a pair of sentences/documents,
* where the sentences and labels are provided in lists.
*
*/
public class CollectionLabeledPairSentenceProvider implements LabeledPairSentenceProvider {
private final List<String> sentenceL;
private final List<String> sentenceR;
private final List<String> labels;
private final Random rng;
private final int[] order;
private final List<String> allLabels;
private int cursor = 0;
/**
* Lists containing sentences to iterate over with a third for labels
* Sentences in the same position in the first two lists are considered a pair
* @param sentenceL
* @param sentenceR
* @param labelsForSentences
*/
public CollectionLabeledPairSentenceProvider(@NonNull List<String> sentenceL, @NonNull List<String> sentenceR,
@NonNull List<String> labelsForSentences) {
this(sentenceL, sentenceR, labelsForSentences, new Random());
}
/**
* Lists containing sentences to iterate over with a third for labels
* Sentences in the same position in the first two lists are considered a pair
* @param sentenceL
* @param sentenceR
* @param labelsForSentences
* @param rng If null, list order is not shuffled
*/
public CollectionLabeledPairSentenceProvider(@NonNull List<String> sentenceL, List<String> sentenceR, @NonNull List<String> labelsForSentences,
Random rng) {
if (sentenceR.size() != sentenceL.size()) {
throw new IllegalArgumentException("Sentence lists must be same size (first list size: "
+ sentenceL.size() + ", second list size: " + sentenceR.size() + ")");
}
if (sentenceR.size() != labelsForSentences.size()) {
throw new IllegalArgumentException("Sentence pairs and labels must be same size (sentence pair size: "
+ sentenceR.size() + ", labels size: " + labelsForSentences.size() + ")");
}
this.sentenceL = sentenceL;
this.sentenceR = sentenceR;
this.labels = labelsForSentences;
this.rng = rng;
if (rng == null) {
order = null;
} else {
order = new int[sentenceR.size()];
for (int i = 0; i < sentenceR.size(); i++) {
order[i] = i;
}
MathUtils.shuffleArray(order, rng);
}
//Collect set of unique labels for all sentences
Set<String> uniqueLabels = new HashSet<>(labelsForSentences);
allLabels = new ArrayList<>(uniqueLabels);
Collections.sort(allLabels);
}
@Override
public boolean hasNext() {
return cursor < sentenceR.size();
}
@Override
public Triple<String, String, String> nextSentencePair() {
Preconditions.checkState(hasNext(),"No next element available");
int idx;
if (rng == null) {
idx = cursor++;
} else {
idx = order[cursor++];
}
return new Triple<>(sentenceL.get(idx), sentenceR.get(idx), labels.get(idx));
}
@Override
public void reset() {
cursor = 0;
if (rng != null) {
MathUtils.shuffleArray(order, rng);
}
}
@Override
public int totalNumSentences() {
return sentenceR.size();
}
@Override
public List<String> allLabels() {
return allLabels;
}
@Override
public int numLabelClasses() {
return allLabels.size();
}
}

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.iterator.provider;
import lombok.NonNull; import lombok.NonNull;
import org.deeplearning4j.iterator.LabeledSentenceProvider; import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.MathUtils; import org.nd4j.linalg.util.MathUtils;
@ -40,15 +41,15 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide
private int cursor = 0; private int cursor = 0;
public CollectionLabeledSentenceProvider(@NonNull List<String> sentences, public CollectionLabeledSentenceProvider(@NonNull List<String> sentences,
@NonNull List<String> labelsForSentences) { @NonNull List<String> labelsForSentences) {
this(sentences, labelsForSentences, new Random()); this(sentences, labelsForSentences, new Random());
} }
public CollectionLabeledSentenceProvider(@NonNull List<String> sentences, @NonNull List<String> labelsForSentences, public CollectionLabeledSentenceProvider(@NonNull List<String> sentences, @NonNull List<String> labelsForSentences,
Random rng) { Random rng) {
if (sentences.size() != labelsForSentences.size()) { if (sentences.size() != labelsForSentences.size()) {
throw new IllegalArgumentException("Sentences and labels must be same size (sentences size: " throw new IllegalArgumentException("Sentences and labels must be same size (sentences size: "
+ sentences.size() + ", labels size: " + labelsForSentences.size() + ")"); + sentences.size() + ", labels size: " + labelsForSentences.size() + ")");
} }
this.sentences = sentences; this.sentences = sentences;
@ -66,10 +67,7 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide
} }
//Collect set of unique labels for all sentences //Collect set of unique labels for all sentences
Set<String> uniqueLabels = new HashSet<>(); Set<String> uniqueLabels = new HashSet<>(labelsForSentences);
for (String s : labelsForSentences) {
uniqueLabels.add(s);
}
allLabels = new ArrayList<>(uniqueLabels); allLabels = new ArrayList<>(uniqueLabels);
Collections.sort(allLabels); Collections.sort(allLabels);
} }
@ -81,6 +79,7 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide
@Override @Override
public Pair<String, String> nextSentence() { public Pair<String, String> nextSentence() {
Preconditions.checkState(hasNext(), "No next element available");
int idx; int idx;
if (rng == null) { if (rng == null) {
idx = cursor++; idx = cursor++;

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -27,6 +28,7 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.io.File; import java.io.File;
@ -43,7 +45,8 @@ public class TestBertIterator extends BaseDL4JTest {
private File pathToVocab = Resources.asFile("other/vocab.txt"); private File pathToVocab = Resources.asFile("other/vocab.txt");
private static Charset c = StandardCharsets.UTF_8; private static Charset c = StandardCharsets.UTF_8;
public TestBertIterator() throws IOException{ } public TestBertIterator() throws IOException {
}
@Test(timeout = 20000L) @Test(timeout = 20000L)
public void testBertSequenceClassification() throws Exception { public void testBertSequenceClassification() throws Exception {
@ -74,8 +77,8 @@ public class TestBertIterator extends BaseDL4JTest {
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens = t.create(toTokenize1).getTokens(); List<String> tokens = t.create(toTokenize1).getTokens();
Map<String,Integer> m = t.getVocab(); Map<String, Integer> m = t.getVocab();
for( int i=0; i<tokens.size(); i++ ){ for (int i = 0; i < tokens.size(); i++) {
int idx = m.get(tokens.get(i)); int idx = m.get(tokens.get(i));
expEx0.putScalar(0, i, idx); expEx0.putScalar(0, i, idx);
expM0.putScalar(0, i, 1); expM0.putScalar(0, i, 1);
@ -84,9 +87,9 @@ public class TestBertIterator extends BaseDL4JTest {
INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens2 = t.create(toTokenize2).getTokens(); List<String> tokens2 = t.create(toTokenize2).getTokens();
for( int i=0; i<tokens2.size(); i++ ){ for (int i = 0; i < tokens2.size(); i++) {
String token = tokens2.get(i); String token = tokens2.get(i);
if(!m.containsKey(token)){ if (!m.containsKey(token)) {
throw new IllegalStateException("Unknown token: \"" + token + "\""); throw new IllegalStateException("Unknown token: \"" + token + "\"");
} }
int idx = m.get(token); int idx = m.get(token);
@ -99,15 +102,15 @@ public class TestBertIterator extends BaseDL4JTest {
assertEquals(expF, mds.getFeatures(0)); assertEquals(expF, mds.getFeatures(0));
assertEquals(expM, mds.getFeaturesMaskArray(0)); assertEquals(expM, mds.getFeaturesMaskArray(0));
assertEquals(expF,b.featurizeSentences(forInference).getFirst()[0]); assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
assertEquals(expM,b.featurizeSentences(forInference).getSecond()[0]); assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
b.next(); //pop the third element
assertFalse(b.hasNext()); assertFalse(b.hasNext());
b.reset(); b.reset();
assertTrue(b.hasNext()); assertTrue(b.hasNext());
MultiDataSet mds2 = b.next();
forInference.set(0,toTokenize2); forInference.set(0, toTokenize2);
//Same thing, but with segment ID also //Same thing, but with segment ID also
b = BertIterator.builder() b = BertIterator.builder()
.tokenizer(t) .tokenizer(t)
@ -120,11 +123,12 @@ public class TestBertIterator extends BaseDL4JTest {
.build(); .build();
mds = b.next(); mds = b.next();
assertEquals(2, mds.getFeatures().length); assertEquals(2, mds.getFeatures().length);
assertEquals(2,b.featurizeSentences(forInference).getFirst().length); //assertEquals(2, mds.getFeaturesMaskArrays().length); second element is null...
assertEquals(2, b.featurizeSentences(forInference).getFirst().length);
//Segment ID should be all 0s for single segment task //Segment ID should be all 0s for single segment task
INDArray segmentId = expM.like(); INDArray segmentId = expM.like();
assertEquals(segmentId, mds.getFeatures(1)); assertEquals(segmentId, mds.getFeatures(1));
assertEquals(segmentId,b.featurizeSentences(forInference).getFirst()[1]); assertEquals(segmentId, b.featurizeSentences(forInference).getFirst()[1]);
} }
@Test(timeout = 20000L) @Test(timeout = 20000L)
@ -152,9 +156,10 @@ public class TestBertIterator extends BaseDL4JTest {
System.out.println(mds.getLabels(0)); System.out.println(mds.getLabels(0));
System.out.println(mds.getLabelsMaskArray(0)); System.out.println(mds.getLabelsMaskArray(0));
b.next(); //pop the third element
assertFalse(b.hasNext()); assertFalse(b.hasNext());
b.reset(); b.reset();
mds = b.next(); assertTrue(b.hasNext());
} }
@Test(timeout = 20000L) @Test(timeout = 20000L)
@ -168,8 +173,9 @@ public class TestBertIterator extends BaseDL4JTest {
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens = t.create(toTokenize1).getTokens(); List<String> tokens = t.create(toTokenize1).getTokens();
Map<String,Integer> m = t.getVocab(); System.out.println(tokens);
for( int i=0; i<tokens.size(); i++ ){ Map<String, Integer> m = t.getVocab();
for (int i = 0; i < tokens.size(); i++) {
int idx = m.get(tokens.get(i)); int idx = m.get(tokens.get(i));
expEx0.putScalar(0, i, idx); expEx0.putScalar(0, i, idx);
expM0.putScalar(0, i, 1); expM0.putScalar(0, i, 1);
@ -178,9 +184,10 @@ public class TestBertIterator extends BaseDL4JTest {
INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens2 = t.create(toTokenize2).getTokens(); List<String> tokens2 = t.create(toTokenize2).getTokens();
for( int i=0; i<tokens2.size(); i++ ){ System.out.println(tokens2);
for (int i = 0; i < tokens2.size(); i++) {
String token = tokens2.get(i); String token = tokens2.get(i);
if(!m.containsKey(token)){ if (!m.containsKey(token)) {
throw new IllegalStateException("Unknown token: \"" + token + "\""); throw new IllegalStateException("Unknown token: \"" + token + "\"");
} }
int idx = m.get(token); int idx = m.get(token);
@ -210,9 +217,9 @@ public class TestBertIterator extends BaseDL4JTest {
long[] expShape = new long[]{2, 14}; long[] expShape = new long[]{2, 14};
assertArrayEquals(expShape, mds.getFeatures(0).shape()); assertArrayEquals(expShape, mds.getFeatures(0).shape());
assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape()); assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeatures(0)); assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeatures(0));
assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeaturesMaskArray(0)); assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeaturesMaskArray(0));
assertEquals(mds.getFeatures(0),b.featurizeSentences(forInference).getFirst()[0]); assertEquals(mds.getFeatures(0), b.featurizeSentences(forInference).getFirst()[0]);
assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]); assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]);
//Clip only: clip to maximum, but don't pad if less //Clip only: clip to maximum, but don't pad if less
@ -236,15 +243,17 @@ public class TestBertIterator extends BaseDL4JTest {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
String toTokenize1 = "I saw a girl with a telescope."; String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
String toTokenize3 = "Goodnight noises everywhere";
List<String> forInference = new ArrayList<>(); List<String> forInference = new ArrayList<>();
forInference.add(toTokenize1); forInference.add(toTokenize1);
forInference.add(toTokenize2); forInference.add(toTokenize2);
forInference.add(toTokenize3);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens = t.create(toTokenize1).getTokens(); List<String> tokens = t.create(toTokenize1).getTokens();
Map<String,Integer> m = t.getVocab(); Map<String, Integer> m = t.getVocab();
for( int i=0; i<tokens.size(); i++ ){ for (int i = 0; i < tokens.size(); i++) {
int idx = m.get(tokens.get(i)); int idx = m.get(tokens.get(i));
expEx0.putScalar(0, i, idx); expEx0.putScalar(0, i, idx);
expM0.putScalar(0, i, 1); expM0.putScalar(0, i, 1);
@ -253,9 +262,9 @@ public class TestBertIterator extends BaseDL4JTest {
INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens2 = t.create(toTokenize2).getTokens(); List<String> tokens2 = t.create(toTokenize2).getTokens();
for( int i=0; i<tokens2.size(); i++ ){ for (int i = 0; i < tokens2.size(); i++) {
String token = tokens2.get(i); String token = tokens2.get(i);
if(!m.containsKey(token)){ if (!m.containsKey(token)) {
throw new IllegalStateException("Unknown token: \"" + token + "\""); throw new IllegalStateException("Unknown token: \"" + token + "\"");
} }
int idx = m.get(token); int idx = m.get(token);
@ -263,14 +272,27 @@ public class TestBertIterator extends BaseDL4JTest {
expM1.putScalar(0, i, 1); expM1.putScalar(0, i, 1);
} }
INDArray zeros = Nd4j.create(DataType.INT, 2, 16); INDArray expEx3 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM3 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens3 = t.create(toTokenize3).getTokens();
for (int i = 0; i < tokens3.size(); i++) {
String token = tokens3.get(i);
if (!m.containsKey(token)) {
throw new IllegalStateException("Unknown token: \"" + token + "\"");
}
int idx = m.get(token);
expEx3.putScalar(0, i, idx);
expM3.putScalar(0, i, 1);
}
INDArray expF = Nd4j.vstack(expEx0, expEx1, zeros); INDArray zeros = Nd4j.create(DataType.INT, 1, 16);
INDArray expM = Nd4j.vstack(expM0, expM1, zeros); INDArray expF = Nd4j.vstack(expEx0, expEx1, expEx3, zeros);
INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {0, 0}, {0, 0}}); INDArray expM = Nd4j.vstack(expM0, expM1, expM3, zeros);
INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {1, 0}, {0, 0}});
INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1); INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1);
expLM.putScalar(0, 0, 1); expLM.putScalar(0, 0, 1);
expLM.putScalar(1, 0, 1); expLM.putScalar(1, 0, 1);
expLM.putScalar(2, 0, 1);
//-------------------------------------------------------------- //--------------------------------------------------------------
@ -305,9 +327,234 @@ public class TestBertIterator extends BaseDL4JTest {
assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]); assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
} }
@Test
public void testSentencePairsSingle() throws IOException {
String shortSent = "I saw a girl with a telescope.";
String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
boolean prependAppend;
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
int shortL = t.create(shortSent).countTokens();
int longL = t.create(longSent).countTokens();
Triple<MultiDataSet, MultiDataSet, MultiDataSet> multiDataSetTriple;
MultiDataSet shortLongPair, shortSentence, longSentence;
// check for pair max length exactly equal to sum of lengths - pop neither no padding
// should be the same as hstack with segment ids 1 for second sentence
prependAppend = true;
multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend);
shortLongPair = multiDataSetTriple.getFirst();
shortSentence = multiDataSetTriple.getSecond();
longSentence = multiDataSetTriple.getThird();
assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
longSentence.getFeatures(1).addi(1);
assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
//check for pair max length greater than sum of lengths - pop neither with padding
// features should be the same as hstack of shorter and longer padded with prepend/append
// segment id should 1 only in the longer for part of the length of the sentence
prependAppend = true;
multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend);
shortLongPair = multiDataSetTriple.getFirst();
shortSentence = multiDataSetTriple.getSecond();
longSentence = multiDataSetTriple.getThird();
assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
longSentence.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part
assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
//check for pair max length less than shorter sentence - pop both
//should be the same as hstack with segment ids 1 for second sentence if no prepend/append
int maxL = shortL - 2;
prependAppend = false;
multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend);
shortLongPair = multiDataSetTriple.getFirst();
shortSentence = multiDataSetTriple.getSecond();
longSentence = multiDataSetTriple.getThird();
assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
longSentence.getFeatures(1).addi(1);
assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
}
@Test
public void testSentencePairsUnequalLengths() throws IOException {
//check for pop only longer (i.e between longer and longer + shorter), first row pop from second sentence, next row pop from first sentence, nothing to pop in the third row
//should be identical to hstack if there is no append, prepend
//batch size is 2
int mbS = 4;
String shortSent = "I saw a girl with a telescope.";
String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
String sent1 = "Goodnight noises everywhere"; //shorter than shortSent - no popping
String sent2 = "Goodnight moon"; //shorter than shortSent - no popping
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
int shortL = t.create(shortSent).countTokens();
int longL = t.create(longSent).countTokens();
int sent1L = t.create(sent1).countTokens();
int sent2L = t.create(sent2).countTokens();
//won't check 2*shortL + 1 because this will always pop on the left
for (int maxL = longL + shortL - 1; maxL > 2 * shortL; maxL--) {
MultiDataSet leftMDS = BertIterator.builder()
.tokenizer(t)
.minibatchSize(mbS)
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
.vocabMap(t.getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
.sentenceProvider(new TestSentenceProvider())
.padMinibatches(true)
.build().next();
MultiDataSet rightMDS = BertIterator.builder()
.tokenizer(t)
.minibatchSize(mbS)
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
.vocabMap(t.getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
.sentenceProvider(new TestSentenceProvider(true))
.padMinibatches(true)
.build().next();
MultiDataSet pairMDS = BertIterator.builder()
.tokenizer(t)
.minibatchSize(mbS)
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
.vocabMap(t.getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) //random big num guaranteed to be longer than either
.sentencePairProvider(new TestSentencePairProvider())
.padMinibatches(true)
.build().next();
//Left sentences here are {{shortSent},
// {longSent},
// {Sent1}}
//Right sentences here are {{longSent},
// {shortSent},
// {Sent2}}
//The sentence pairs here are {{shortSent,longSent},
// {longSent,shortSent}
// {Sent1, Sent2}}
//CHECK FEATURES
INDArray combinedFeat = Nd4j.create(DataType.INT,mbS,maxL);
//left side
INDArray leftFeatures = leftMDS.getFeatures(0);
INDArray topLSentFeat = leftFeatures.getRow(0).get(NDArrayIndex.interval(0, shortL));
INDArray midLSentFeat = leftFeatures.getRow(1).get(NDArrayIndex.interval(0, maxL - shortL));
INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0,sent1L));
//right side
INDArray rightFeatures = rightMDS.getFeatures(0);
INDArray topRSentFeat = rightFeatures.getRow(0).get(NDArrayIndex.interval(0, maxL - shortL));
INDArray midRSentFeat = rightFeatures.getRow(1).get(NDArrayIndex.interval(0, shortL));
INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0,sent2L));
//expected pair
combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat,topRSentFeat));
combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat,midRSentFeat));
combinedFeat.getRow(2).get(NDArrayIndex.interval(0,sent1L+sent2L)).addi(Nd4j.hstack(bottomLSentFeat,bottomRSentFeat));
assertEquals(maxL, pairMDS.getFeatures(0).shape()[1]);
assertArrayEquals(combinedFeat.shape(), pairMDS.getFeatures(0).shape());
assertEquals(combinedFeat, pairMDS.getFeatures(0));
//CHECK SEGMENT ID
INDArray combinedFetSeg = Nd4j.create(DataType.INT, mbS, maxL);
combinedFetSeg.get(NDArrayIndex.point(0), NDArrayIndex.interval(shortL, maxL)).addi(1);
combinedFetSeg.get(NDArrayIndex.point(1), NDArrayIndex.interval(maxL - shortL, maxL)).addi(1);
combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L+sent2L)).addi(1);
assertArrayEquals(combinedFetSeg.shape(), pairMDS.getFeatures(1).shape());
assertEquals(maxL, combinedFetSeg.shape()[1]);
assertEquals(combinedFetSeg, pairMDS.getFeatures(1));
}
}
@Test
public void testSentencePairFeaturizer() throws IOException {
String shortSent = "I saw a girl with a telescope.";
String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
List<Pair<String, String>> listSentencePair = new ArrayList<>();
listSentencePair.add(new Pair<>(shortSent, longSent));
listSentencePair.add(new Pair<>(longSent, shortSent));
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
BertIterator b = BertIterator.builder()
.tokenizer(t)
.minibatchSize(2)
.padMinibatches(true)
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
.vocabMap(t.getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)
.sentencePairProvider(new TestSentencePairProvider())
.prependToken("[CLS]")
.appendToken("[SEP]")
.build();
MultiDataSet mds = b.next();
INDArray[] featuresArr = mds.getFeatures();
INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays();
Pair<INDArray[], INDArray[]> p = b.featurizeSentencePairs(listSentencePair);
assertEquals(p.getFirst().length, 2);
assertEquals(featuresArr[0], p.getFirst()[0]);
assertEquals(featuresArr[1], p.getFirst()[1]);
//assertEquals(p.getSecond().length, 2);
assertEquals(featuresMaskArr[0], p.getSecond()[0]);
//assertEquals(featuresMaskArr[1], p.getSecond()[1]);
}
/**
* Returns three multidatasets from bert iterator based on given max lengths and whether to prepend/append
* Idea is the sentence pair dataset can be constructed from the single sentence datasets
* First one is constructed from a sentence pair "I saw a girl with a telescope." & "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"
* Second one is constructed from the left of the sentence pair i.e "I saw a girl with a telescope."
* Third one is constructed from the right of the sentence pair i.e "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"
*/
private Triple<MultiDataSet, MultiDataSet, MultiDataSet> generateMultiDataSets(Triple<Integer, Integer, Integer> maxLengths, boolean prependAppend) throws IOException {
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
int maxforPair = maxLengths.getFirst();
int maxPartOne = maxLengths.getSecond();
int maxPartTwo = maxLengths.getThird();
BertIterator.Builder commonBuilder;
commonBuilder = BertIterator.builder()
.tokenizer(t)
.minibatchSize(1)
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
.vocabMap(t.getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION);
BertIterator shortLongPairFirstIter = commonBuilder
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxforPair + 3 : maxforPair)
.sentencePairProvider(new TestSentencePairProvider())
.prependToken(prependAppend ? "[CLS]" : null)
.appendToken(prependAppend ? "[SEP]" : null)
.build();
BertIterator shortFirstIter = commonBuilder
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartOne + 2 : maxPartOne)
.sentenceProvider(new TestSentenceProvider())
.prependToken(prependAppend ? "[CLS]" : null)
.appendToken(prependAppend ? "[SEP]" : null)
.build();
BertIterator longFirstIter = commonBuilder
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartTwo + 1 : maxPartTwo)
.sentenceProvider(new TestSentenceProvider(true))
.prependToken(null)
.appendToken(prependAppend ? "[SEP]" : null)
.build();
return new Triple<>(shortLongPairFirstIter.next(), shortFirstIter.next(), longFirstIter.next());
}
private static class TestSentenceProvider implements LabeledSentenceProvider { private static class TestSentenceProvider implements LabeledSentenceProvider {
private int pos = 0; private int pos = 0;
private boolean invert;
private TestSentenceProvider() {
this.invert = false;
}
private TestSentenceProvider(boolean invert) {
this.invert = invert;
}
@Override @Override
public boolean hasNext() { public boolean hasNext() {
@ -317,10 +564,20 @@ public class TestBertIterator extends BaseDL4JTest {
@Override @Override
public Pair<String, String> nextSentence() { public Pair<String, String> nextSentence() {
Preconditions.checkState(hasNext()); Preconditions.checkState(hasNext());
if(pos++ == 0){ if (pos == 0) {
return new Pair<>("I saw a girl with a telescope.", "positive"); pos++;
} else { if (!invert) return new Pair<>("I saw a girl with a telescope.", "positive");
return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative"); return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
} else {
if (pos == 1) {
pos++;
if (!invert) return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
return new Pair<>("I saw a girl with a telescope.", "positive");
}
pos++;
if (!invert)
return new Pair<>("Goodnight noises everywhere", "positive");
return new Pair<>("Goodnight moon", "positive");
} }
} }
@ -331,8 +588,54 @@ public class TestBertIterator extends BaseDL4JTest {
@Override @Override
public int totalNumSentences() { public int totalNumSentences() {
return 3;
}
@Override
public List<String> allLabels() {
return Arrays.asList("positive", "negative");
}
@Override
public int numLabelClasses() {
return 2; return 2;
} }
}
private static class TestSentencePairProvider implements LabeledPairSentenceProvider {
private int pos = 0;
@Override
public boolean hasNext() {
return pos < totalNumSentences();
}
@Override
public Triple<String, String, String> nextSentencePair() {
Preconditions.checkState(hasNext());
if (pos == 0) {
pos++;
return new Triple<>("I saw a girl with a telescope.", "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "positive");
} else {
if (pos == 1) {
pos++;
return new Triple<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "I saw a girl with a telescope.", "negative");
}
pos++;
return new Triple<>("Goodnight noises everywhere", "Goodnight moon", "positive");
}
}
@Override
public void reset() {
pos = 0;
}
@Override
public int totalNumSentences() {
return 3;
}
@Override @Override
public List<String> allLabels() { public List<String> allLabels() {