diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java index b5ca6c91a..40c43113c 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java @@ -34,6 +34,7 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.primitives.Pair; +import org.nd4j.linalg.primitives.Triple; import java.util.ArrayList; import java.util.Arrays; @@ -85,10 +86,20 @@ import java.util.Map; *
  * {@code
  *          BertIterator b;
+ *          Pair featuresAndMask;
+ *          INDArray[] features;
+ *          INDArray[] featureMasks;
+ *
+ *          //With sentences
  *          List forInference;
- *          Pair featuresAndMask = b.featurizeSentences(forInference);
- *          INDArray[] features = featuresAndMask.getFirst();
- *          INDArray[] featureMasks = featuresAndMask.getSecond();
+ *          featuresAndMask = b.featurizeSentences(forInference);
+ *
+ *          //OR with sentence pairs
+ *          List> forInferencePair};
+ *          featuresAndMask = b.featurizeSentencePairs(forInference);
+ *
+ *          features = featuresAndMask.getFirst();
+ *          featureMasks = featuresAndMask.getSecond();
  * }
  * 
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.
@@ -135,6 +146,7 @@ public class BertIterator implements MultiDataSetIterator { @Setter protected MultiDataSetPreProcessor preProcessor; protected LabeledSentenceProvider sentenceProvider = null; + protected LabeledPairSentenceProvider sentencePairProvider = null; protected LengthHandling lengthHandling; protected FeatureArrays featureArrays; protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects? @@ -142,6 +154,7 @@ public class BertIterator implements MultiDataSetIterator { protected UnsupervisedLabelFormat unsupervisedLabelFormat = null; protected String maskToken; protected String prependToken; + protected String appendToken; protected List vocabKeysAsList; @@ -154,6 +167,7 @@ public class BertIterator implements MultiDataSetIterator { this.padMinibatches = b.padMinibatches; this.preProcessor = b.preProcessor; this.sentenceProvider = b.sentenceProvider; + this.sentencePairProvider = b.sentencePairProvider; this.lengthHandling = b.lengthHandling; this.featureArrays = b.featureArrays; this.vocabMap = b.vocabMap; @@ -161,11 +175,14 @@ public class BertIterator implements MultiDataSetIterator { this.unsupervisedLabelFormat = b.unsupervisedLabelFormat; this.maskToken = b.maskToken; this.prependToken = b.prependToken; + this.appendToken = b.appendToken; } @Override public boolean hasNext() { - return sentenceProvider.hasNext(); + if (sentenceProvider != null) + return sentenceProvider.hasNext(); + return sentencePairProvider.hasNext(); } @Override @@ -181,29 +198,38 @@ public class BertIterator implements MultiDataSetIterator { @Override public MultiDataSet next(int num) { Preconditions.checkState(hasNext(), "No next element available"); - - List> list = new ArrayList<>(num); + List, String>> tokensAndLabelList; int mbSize = 0; + int outLength; + long[] segIdOnesFrom = null; if (sentenceProvider != null) { + List> list = new ArrayList<>(num); while (sentenceProvider.hasNext() && mbSize++ < num) { list.add(sentenceProvider.nextSentence()); } + SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(list); + tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList(); + outLength = sentenceListProcessed.getMaxL(); + } else if (sentencePairProvider != null) { + List> listPairs = new ArrayList<>(num); + while (sentencePairProvider.hasNext() && mbSize++ < num) { + listPairs.add(sentencePairProvider.nextSentencePair()); + } + SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(listPairs); + tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList(); + outLength = sentencePairListProcessed.getMaxL(); + segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom(); } else { //TODO - other types of iterators... throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented"); } - - Pair, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list); - List, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight(); - int outLength = outLTokenizedSentencesPair.getLeft(); - - Pair featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength); + Pair featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom); INDArray[] featureArray = featuresAndMaskArraysPair.getFirst(); INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond(); - Pair labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength); + Pair labelsAndMaskArraysPair = convertMiniBatchLabels(tokensAndLabelList, featureArray, outLength); INDArray[] labelArray = labelsAndMaskArraysPair.getFirst(); INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond(); @@ -224,32 +250,59 @@ public class BertIterator implements MultiDataSetIterator { public Pair featurizeSentences(List listOnlySentences) { List> sentencesWithNullLabel = addDummyLabel(listOnlySentences); + SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(sentencesWithNullLabel); + List, String>> tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList(); + int outLength = sentenceListProcessed.getMaxL(); - Pair, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel); - List, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight(); - int outLength = outLTokenizedSentencesPair.getLeft(); - - Pair featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength); if (preProcessor != null) { + Pair featureFeatureMasks = convertMiniBatchFeatures(tokensAndLabelList, outLength, null); MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null); preProcessor.preProcess(dummyMDS); - return new Pair(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays()); + return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays()); } - return convertMiniBatchFeatures(tokenizedSentences, outLength); + return convertMiniBatchFeatures(tokensAndLabelList, outLength, null); } - private Pair convertMiniBatchFeatures(List, String>> tokenizedSentences, int outLength) { - int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size(); + /** + * For use during inference. Will convert a given pair of a list of sentences to features and feature masks as appropriate. + * + * @param listOnlySentencePairs + * @return Pair of INDArrays[], first element is feature arrays and the second is the masks array + */ + public Pair featurizeSentencePairs(List> listOnlySentencePairs) { + Preconditions.checkState(sentencePairProvider != null, "The featurizeSentencePairs method is meant for inference with sentence pairs. Use only when the sentence pair provider is set (i.e not null)."); + + List> sentencePairsWithNullLabel = addDummyLabelForPairs(listOnlySentencePairs); + SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(sentencePairsWithNullLabel); + List, String>> tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList(); + int outLength = sentencePairListProcessed.getMaxL(); + long[] segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom(); + if (preProcessor != null) { + Pair featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom); + MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featuresAndMaskArraysPair.getFirst(), null, featuresAndMaskArraysPair.getSecond(), null); + preProcessor.preProcess(dummyMDS); + return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays()); + } + return convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom); + } + + private Pair convertMiniBatchFeatures(List, String>> tokensAndLabelList, int outLength, long[] segIdOnesFrom) { + int mbPadded = padMinibatches ? minibatchSize : tokensAndLabelList.size(); int[][] outIdxs = new int[mbPadded][outLength]; int[][] outMask = new int[mbPadded][outLength]; - for (int i = 0; i < tokenizedSentences.size(); i++) { - Pair, String> p = tokenizedSentences.get(i); + int[][] outSegmentId = null; + if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) + outSegmentId = new int[mbPadded][outLength]; + for (int i = 0; i < tokensAndLabelList.size(); i++) { + Pair, String> p = tokensAndLabelList.get(i); List t = p.getFirst(); for (int j = 0; j < outLength && j < t.size(); j++) { Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j)); int idx = vocabMap.get(t.get(j)); outIdxs[i][j] = idx; outMask[i][j] = 1; + if (segIdOnesFrom != null && j >= segIdOnesFrom[i]) + outSegmentId[i][j] = 1; } } @@ -260,8 +313,7 @@ public class BertIterator implements MultiDataSetIterator { INDArray[] f; INDArray[] fm; if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) { - //For now: always segment index 0 (only single s sequence input supported) - outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength); + outSegmentIdArr = Nd4j.createFromArray(outSegmentId); f = new INDArray[]{outIdxsArr, outSegmentIdArr}; fm = new INDArray[]{outMaskArr, null}; } else { @@ -271,16 +323,15 @@ public class BertIterator implements MultiDataSetIterator { return new Pair<>(f, fm); } - private Pair, String>>> tokenizeMiniBatch(List> list) { + private SentenceListProcessed tokenizeMiniBatch(List> list) { //Get and tokenize the sentences for this minibatch - List, String>> tokenizedSentences = new ArrayList<>(list.size()); + SentenceListProcessed sentenceListProcessed = new SentenceListProcessed(list.size()); int longestSeq = -1; for (Pair p : list) { List tokens = tokenizeSentence(p.getFirst()); - tokenizedSentences.add(new Pair<>(tokens, p.getSecond())); + sentenceListProcessed.addProcessedToList(new Pair<>(tokens, p.getSecond())); longestSeq = Math.max(longestSeq, tokens.size()); } - //Determine output array length... int outLength; switch (lengthHandling) { @@ -296,7 +347,52 @@ public class BertIterator implements MultiDataSetIterator { default: throw new RuntimeException("Not implemented length handling mode: " + lengthHandling); } - return new Pair<>(outLength, tokenizedSentences); + sentenceListProcessed.setMaxL(outLength); + return sentenceListProcessed; + } + + private SentencePairListProcessed tokenizePairsMiniBatch(List> listPairs) { + SentencePairListProcessed sentencePairListProcessed = new SentencePairListProcessed(listPairs.size()); + for (Triple t : listPairs) { + List tokensL = tokenizeSentence(t.getFirst(), true); + List tokensR = tokenizeSentence(t.getSecond(), true); + List tokens = new ArrayList<>(maxTokens); + int maxLength = maxTokens; + if (prependToken != null) + maxLength--; + if (appendToken != null) + maxLength -= 2; + if (tokensL.size() + tokensR.size() > maxLength) { + boolean shortOnL = tokensL.size() < tokensR.size(); + int shortSize = Math.min(tokensL.size(), tokensR.size()); + if (shortSize > maxLength / 2) { + //both lists need to be sliced + tokensL.subList(maxLength / 2, tokensL.size()).clear(); //if maxsize/2 is odd pop extra on L side to match implementation in TF + tokensR.subList(maxLength - maxLength / 2, tokensR.size()).clear(); + } else { + //slice longer list + if (shortOnL) { + //longer on R - slice R + tokensR.subList(maxLength - tokensL.size(), tokensR.size()).clear(); + } else { + //longer on L - slice L + tokensL.subList(maxLength - tokensR.size(), tokensL.size()).clear(); + } + } + } + if (prependToken != null) + tokens.add(prependToken); + tokens.addAll(tokensL); + if (appendToken != null) + tokens.add(appendToken); + int segIdOnesFrom = tokens.size(); + tokens.addAll(tokensR); + if (appendToken != null) + tokens.add(appendToken); + sentencePairListProcessed.addProcessedToList(segIdOnesFrom, new Pair<>(tokens, t.getThird())); + } + sentencePairListProcessed.setMaxL(maxTokens); + return sentencePairListProcessed; } private Pair convertMiniBatchLabels(List, String>> tokenizedSentences, INDArray[] featureArray, int outLength) { @@ -316,6 +412,14 @@ public class BertIterator implements MultiDataSetIterator { classLabels[i] = labels.indexOf(lbl); Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl); } + } else if (sentencePairProvider != null) { + numClasses = sentencePairProvider.numLabelClasses(); + List labels = sentencePairProvider.allLabels(); + for (int i = 0; i < mbSize; i++) { + String lbl = tokenizedSentences.get(i).getRight(); + classLabels[i] = labels.indexOf(lbl); + Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl); + } } else { throw new RuntimeException(); } @@ -392,16 +496,22 @@ public class BertIterator implements MultiDataSetIterator { } private List tokenizeSentence(String sentence) { + return tokenizeSentence(sentence, false); + } + + private List tokenizeSentence(String sentence, boolean ignorePrependAppend) { Tokenizer t = tokenizerFactory.create(sentence); List tokens = new ArrayList<>(); - if (prependToken != null) + if (prependToken != null && !ignorePrependAppend) tokens.add(prependToken); while (t.hasMoreTokens()) { String token = t.nextToken(); tokens.add(token); } + if (appendToken != null && !ignorePrependAppend) + tokens.add(appendToken); return tokens; } @@ -414,6 +524,13 @@ public class BertIterator implements MultiDataSetIterator { return list; } + private List> addDummyLabelForPairs(List> listOnlySentencePairs) { + List> list = new ArrayList<>(listOnlySentencePairs.size()); + for (Pair p : listOnlySentencePairs) { + list.add(new Triple(p.getFirst(), p.getSecond(), null)); + } + return list; + } @Override public boolean resetSupported() { @@ -446,12 +563,14 @@ public class BertIterator implements MultiDataSetIterator { protected boolean padMinibatches = false; protected MultiDataSetPreProcessor preProcessor; protected LabeledSentenceProvider sentenceProvider = null; + protected LabeledPairSentenceProvider sentencePairProvider = null; protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID; protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects? protected BertSequenceMasker masker = new BertMaskedLMMasker(); protected UnsupervisedLabelFormat unsupervisedLabelFormat; protected String maskToken; protected String prependToken; + protected String appendToken; /** * Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details. @@ -519,14 +638,21 @@ public class BertIterator implements MultiDataSetIterator { } /** - * Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised - * use case, the labels will be ignored. + * Specify the source of the data for classification. */ public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) { this.sentenceProvider = sentenceProvider; return this; } + /** + * Specify the source of the data for classification on sentence pairs. + */ + public Builder sentencePairProvider(LabeledPairSentenceProvider sentencePairProvider) { + this.sentencePairProvider = sentencePairProvider; + return this; + } + /** * Specify what arrays should be returned. See {@link BertIterator} for more details. */ @@ -591,6 +717,19 @@ public class BertIterator implements MultiDataSetIterator { return this; } + /** + * Append the specified token to the sequences, when doing training on sentence pairs.
+ * Generally "[SEP]" is used + * No token in appended by default. + * + * @param appendToken Token at end of each sentence for pairs of sentences (null: no token will be appended) + * @return + */ + public Builder appendToken(String appendToken) { + this.appendToken = appendToken; + return this; + } + public BertIterator build() { Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed"); Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required"); @@ -598,9 +737,69 @@ public class BertIterator implements MultiDataSetIterator { Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method"); Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via masker(BertSequenceMasker) method"); Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\" must be specified"); - + if (sentencePairProvider != null) { + Preconditions.checkState(task == Task.SEQ_CLASSIFICATION, "Currently only supervised sequence classification is set up with sentence pairs. \".task(BertIterator.Task.SEQ_CLASSIFICATION)\" is required with a sentence pair provider"); + Preconditions.checkState(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID, "Currently only supervised sequence classification is set up with sentence pairs. \".featureArrays(FeatureArrays.INDICES_MASK_SEGMENTID)\" is required with a sentence pair provider"); + Preconditions.checkState(lengthHandling == LengthHandling.FIXED_LENGTH, "Currently only fixed length is supported for sentence pairs. \".lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxLength)\" is required with a sentence pair provider"); + Preconditions.checkState(sentencePairProvider != null, "Provide either a sentence provider or a sentence pair provider. Both cannot be non null"); + } + if (appendToken != null) { + Preconditions.checkState(sentencePairProvider != null, "Tokens are only appended with sentence pairs. Sentence pair provider is not set. Set sentence pair provider."); + } return new BertIterator(this); } } + private static class SentencePairListProcessed { + private int listLength = 0; + + @Getter + private long[] segIdOnesFrom; + private int cursor = 0; + private SentenceListProcessed sentenceListProcessed; + + private SentencePairListProcessed(int listLength) { + this.listLength = listLength; + segIdOnesFrom = new long[listLength]; + sentenceListProcessed = new SentenceListProcessed(listLength); + } + + private void addProcessedToList(long segIdIdx, Pair, String> tokenizedSentencePairAndLabel) { + segIdOnesFrom[cursor] = segIdIdx; + sentenceListProcessed.addProcessedToList(tokenizedSentencePairAndLabel); + cursor++; + } + + private void setMaxL(int maxL) { + sentenceListProcessed.setMaxL(maxL); + } + + private int getMaxL() { + return sentenceListProcessed.getMaxL(); + } + + private List, String>> getTokensAndLabelList() { + return sentenceListProcessed.getTokensAndLabelList(); + } + } + + private static class SentenceListProcessed { + private int listLength; + + @Getter + @Setter + private int maxL; + + @Getter + private List, String>> tokensAndLabelList; + + private SentenceListProcessed(int listLength) { + this.listLength = listLength; + tokensAndLabelList = new ArrayList<>(listLength); + } + + private void addProcessedToList(Pair, String> tokenizedSentenceAndLabel) { + tokensAndLabelList.add(tokenizedSentenceAndLabel); + } + } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/LabeledPairSentenceProvider.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/LabeledPairSentenceProvider.java new file mode 100644 index 000000000..ee68477ee --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/LabeledPairSentenceProvider.java @@ -0,0 +1,60 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.iterator; + +import org.nd4j.linalg.primitives.Triple; + +import java.util.List; + +/** + * LabeledPairSentenceProvider: a simple iterator interface over a pair of sentences/documents that have a label.
+ */ +public interface LabeledPairSentenceProvider { + + /** + * Are there more sentences/documents available? + */ + boolean hasNext(); + + /** + * @return Triple: two sentence/document texts and label + */ + Triple nextSentencePair(); + + /** + * Reset the iterator - including shuffling the order, if necessary/appropriate + */ + void reset(); + + /** + * Return the total number of sentences, or -1 if not available + */ + int totalNumSentences(); + + /** + * Return the list of labels - this also defines the class/integer label assignment order + */ + List allLabels(); + + /** + * Equivalent to allLabels().size() + */ + int numLabelClasses(); + +} + + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledPairSentenceProvider.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledPairSentenceProvider.java new file mode 100644 index 000000000..c3c752bed --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledPairSentenceProvider.java @@ -0,0 +1,135 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.iterator.provider; + +import lombok.NonNull; +import org.deeplearning4j.iterator.LabeledPairSentenceProvider; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.primitives.Triple; +import org.nd4j.linalg.util.MathUtils; + +import java.util.*; + +/** + * Iterate over a pair of sentences/documents, + * where the sentences and labels are provided in lists. + * + */ +public class CollectionLabeledPairSentenceProvider implements LabeledPairSentenceProvider { + + private final List sentenceL; + private final List sentenceR; + private final List labels; + private final Random rng; + private final int[] order; + private final List allLabels; + + private int cursor = 0; + + /** + * Lists containing sentences to iterate over with a third for labels + * Sentences in the same position in the first two lists are considered a pair + * @param sentenceL + * @param sentenceR + * @param labelsForSentences + */ + public CollectionLabeledPairSentenceProvider(@NonNull List sentenceL, @NonNull List sentenceR, + @NonNull List labelsForSentences) { + this(sentenceL, sentenceR, labelsForSentences, new Random()); + } + + /** + * Lists containing sentences to iterate over with a third for labels + * Sentences in the same position in the first two lists are considered a pair + * @param sentenceL + * @param sentenceR + * @param labelsForSentences + * @param rng If null, list order is not shuffled + */ + public CollectionLabeledPairSentenceProvider(@NonNull List sentenceL, List sentenceR, @NonNull List labelsForSentences, + Random rng) { + if (sentenceR.size() != sentenceL.size()) { + throw new IllegalArgumentException("Sentence lists must be same size (first list size: " + + sentenceL.size() + ", second list size: " + sentenceR.size() + ")"); + } + if (sentenceR.size() != labelsForSentences.size()) { + throw new IllegalArgumentException("Sentence pairs and labels must be same size (sentence pair size: " + + sentenceR.size() + ", labels size: " + labelsForSentences.size() + ")"); + } + + this.sentenceL = sentenceL; + this.sentenceR = sentenceR; + this.labels = labelsForSentences; + this.rng = rng; + if (rng == null) { + order = null; + } else { + order = new int[sentenceR.size()]; + for (int i = 0; i < sentenceR.size(); i++) { + order[i] = i; + } + + MathUtils.shuffleArray(order, rng); + } + + //Collect set of unique labels for all sentences + Set uniqueLabels = new HashSet<>(labelsForSentences); + allLabels = new ArrayList<>(uniqueLabels); + Collections.sort(allLabels); + } + + @Override + public boolean hasNext() { + return cursor < sentenceR.size(); + } + + @Override + public Triple nextSentencePair() { + Preconditions.checkState(hasNext(),"No next element available"); + int idx; + if (rng == null) { + idx = cursor++; + } else { + idx = order[cursor++]; + } + return new Triple<>(sentenceL.get(idx), sentenceR.get(idx), labels.get(idx)); + } + + @Override + public void reset() { + cursor = 0; + if (rng != null) { + MathUtils.shuffleArray(order, rng); + } + } + + @Override + public int totalNumSentences() { + return sentenceR.size(); + } + + @Override + public List allLabels() { + return allLabels; + } + + @Override + public int numLabelClasses() { + return allLabels.size(); + } +} + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledSentenceProvider.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledSentenceProvider.java index 3dbaa7db8..e6d65b48b 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledSentenceProvider.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledSentenceProvider.java @@ -18,6 +18,7 @@ package org.deeplearning4j.iterator.provider; import lombok.NonNull; import org.deeplearning4j.iterator.LabeledSentenceProvider; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.MathUtils; @@ -40,15 +41,15 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide private int cursor = 0; public CollectionLabeledSentenceProvider(@NonNull List sentences, - @NonNull List labelsForSentences) { + @NonNull List labelsForSentences) { this(sentences, labelsForSentences, new Random()); } public CollectionLabeledSentenceProvider(@NonNull List sentences, @NonNull List labelsForSentences, - Random rng) { + Random rng) { if (sentences.size() != labelsForSentences.size()) { throw new IllegalArgumentException("Sentences and labels must be same size (sentences size: " - + sentences.size() + ", labels size: " + labelsForSentences.size() + ")"); + + sentences.size() + ", labels size: " + labelsForSentences.size() + ")"); } this.sentences = sentences; @@ -66,10 +67,7 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide } //Collect set of unique labels for all sentences - Set uniqueLabels = new HashSet<>(); - for (String s : labelsForSentences) { - uniqueLabels.add(s); - } + Set uniqueLabels = new HashSet<>(labelsForSentences); allLabels = new ArrayList<>(uniqueLabels); Collections.sort(allLabels); } @@ -81,6 +79,7 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide @Override public Pair nextSentence() { + Preconditions.checkState(hasNext(), "No next element available"); int idx; if (rng == null) { idx = cursor++; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java index d4be5e352..a6716ba40 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -27,6 +28,7 @@ import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.primitives.Pair; +import org.nd4j.linalg.primitives.Triple; import org.nd4j.resources.Resources; import java.io.File; @@ -43,7 +45,8 @@ public class TestBertIterator extends BaseDL4JTest { private File pathToVocab = Resources.asFile("other/vocab.txt"); private static Charset c = StandardCharsets.UTF_8; - public TestBertIterator() throws IOException{ } + public TestBertIterator() throws IOException { + } @Test(timeout = 20000L) public void testBertSequenceClassification() throws Exception { @@ -74,8 +77,8 @@ public class TestBertIterator extends BaseDL4JTest { INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); List tokens = t.create(toTokenize1).getTokens(); - Map m = t.getVocab(); - for( int i=0; i m = t.getVocab(); + for (int i = 0; i < tokens.size(); i++) { int idx = m.get(tokens.get(i)); expEx0.putScalar(0, i, idx); expM0.putScalar(0, i, 1); @@ -84,9 +87,9 @@ public class TestBertIterator extends BaseDL4JTest { INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); List tokens2 = t.create(toTokenize2).getTokens(); - for( int i=0; i tokens = t.create(toTokenize1).getTokens(); - Map m = t.getVocab(); - for( int i=0; i m = t.getVocab(); + for (int i = 0; i < tokens.size(); i++) { int idx = m.get(tokens.get(i)); expEx0.putScalar(0, i, idx); expM0.putScalar(0, i, 1); @@ -178,9 +184,10 @@ public class TestBertIterator extends BaseDL4JTest { INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); List tokens2 = t.create(toTokenize2).getTokens(); - for( int i=0; i forInference = new ArrayList<>(); forInference.add(toTokenize1); forInference.add(toTokenize2); + forInference.add(toTokenize3); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); List tokens = t.create(toTokenize1).getTokens(); - Map m = t.getVocab(); - for( int i=0; i m = t.getVocab(); + for (int i = 0; i < tokens.size(); i++) { int idx = m.get(tokens.get(i)); expEx0.putScalar(0, i, idx); expM0.putScalar(0, i, 1); @@ -253,9 +262,9 @@ public class TestBertIterator extends BaseDL4JTest { INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); List tokens2 = t.create(toTokenize2).getTokens(); - for( int i=0; i tokens3 = t.create(toTokenize3).getTokens(); + for (int i = 0; i < tokens3.size(); i++) { + String token = tokens3.get(i); + if (!m.containsKey(token)) { + throw new IllegalStateException("Unknown token: \"" + token + "\""); + } + int idx = m.get(token); + expEx3.putScalar(0, i, idx); + expM3.putScalar(0, i, 1); + } - INDArray expF = Nd4j.vstack(expEx0, expEx1, zeros); - INDArray expM = Nd4j.vstack(expM0, expM1, zeros); - INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {0, 0}, {0, 0}}); + INDArray zeros = Nd4j.create(DataType.INT, 1, 16); + INDArray expF = Nd4j.vstack(expEx0, expEx1, expEx3, zeros); + INDArray expM = Nd4j.vstack(expM0, expM1, expM3, zeros); + INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {1, 0}, {0, 0}}); INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1); expLM.putScalar(0, 0, 1); expLM.putScalar(1, 0, 1); + expLM.putScalar(2, 0, 1); //-------------------------------------------------------------- @@ -305,9 +327,234 @@ public class TestBertIterator extends BaseDL4JTest { assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]); } + @Test + public void testSentencePairsSingle() throws IOException { + String shortSent = "I saw a girl with a telescope."; + String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + boolean prependAppend; + BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + int shortL = t.create(shortSent).countTokens(); + int longL = t.create(longSent).countTokens(); + + Triple multiDataSetTriple; + MultiDataSet shortLongPair, shortSentence, longSentence; + + // check for pair max length exactly equal to sum of lengths - pop neither no padding + // should be the same as hstack with segment ids 1 for second sentence + prependAppend = true; + multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend); + shortLongPair = multiDataSetTriple.getFirst(); + shortSentence = multiDataSetTriple.getSecond(); + longSentence = multiDataSetTriple.getThird(); + assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0))); + longSentence.getFeatures(1).addi(1); + assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1))); + assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0))); + + //check for pair max length greater than sum of lengths - pop neither with padding + // features should be the same as hstack of shorter and longer padded with prepend/append + // segment id should 1 only in the longer for part of the length of the sentence + prependAppend = true; + multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend); + shortLongPair = multiDataSetTriple.getFirst(); + shortSentence = multiDataSetTriple.getSecond(); + longSentence = multiDataSetTriple.getThird(); + assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0))); + longSentence.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part + assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1))); + assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0))); + + //check for pair max length less than shorter sentence - pop both + //should be the same as hstack with segment ids 1 for second sentence if no prepend/append + int maxL = shortL - 2; + prependAppend = false; + multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend); + shortLongPair = multiDataSetTriple.getFirst(); + shortSentence = multiDataSetTriple.getSecond(); + longSentence = multiDataSetTriple.getThird(); + assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0))); + longSentence.getFeatures(1).addi(1); + assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1))); + assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0))); + } + + @Test + public void testSentencePairsUnequalLengths() throws IOException { + //check for pop only longer (i.e between longer and longer + shorter), first row pop from second sentence, next row pop from first sentence, nothing to pop in the third row + //should be identical to hstack if there is no append, prepend + //batch size is 2 + int mbS = 4; + String shortSent = "I saw a girl with a telescope."; + String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + String sent1 = "Goodnight noises everywhere"; //shorter than shortSent - no popping + String sent2 = "Goodnight moon"; //shorter than shortSent - no popping + BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + int shortL = t.create(shortSent).countTokens(); + int longL = t.create(longSent).countTokens(); + int sent1L = t.create(sent1).countTokens(); + int sent2L = t.create(sent2).countTokens(); + //won't check 2*shortL + 1 because this will always pop on the left + for (int maxL = longL + shortL - 1; maxL > 2 * shortL; maxL--) { + MultiDataSet leftMDS = BertIterator.builder() + .tokenizer(t) + .minibatchSize(mbS) + .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) + .vocabMap(t.getVocab()) + .task(BertIterator.Task.SEQ_CLASSIFICATION) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either + .sentenceProvider(new TestSentenceProvider()) + .padMinibatches(true) + .build().next(); + + MultiDataSet rightMDS = BertIterator.builder() + .tokenizer(t) + .minibatchSize(mbS) + .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) + .vocabMap(t.getVocab()) + .task(BertIterator.Task.SEQ_CLASSIFICATION) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either + .sentenceProvider(new TestSentenceProvider(true)) + .padMinibatches(true) + .build().next(); + + MultiDataSet pairMDS = BertIterator.builder() + .tokenizer(t) + .minibatchSize(mbS) + .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) + .vocabMap(t.getVocab()) + .task(BertIterator.Task.SEQ_CLASSIFICATION) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) //random big num guaranteed to be longer than either + .sentencePairProvider(new TestSentencePairProvider()) + .padMinibatches(true) + .build().next(); + + //Left sentences here are {{shortSent}, + // {longSent}, + // {Sent1}} + //Right sentences here are {{longSent}, + // {shortSent}, + // {Sent2}} + //The sentence pairs here are {{shortSent,longSent}, + // {longSent,shortSent} + // {Sent1, Sent2}} + + //CHECK FEATURES + INDArray combinedFeat = Nd4j.create(DataType.INT,mbS,maxL); + //left side + INDArray leftFeatures = leftMDS.getFeatures(0); + INDArray topLSentFeat = leftFeatures.getRow(0).get(NDArrayIndex.interval(0, shortL)); + INDArray midLSentFeat = leftFeatures.getRow(1).get(NDArrayIndex.interval(0, maxL - shortL)); + INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0,sent1L)); + //right side + INDArray rightFeatures = rightMDS.getFeatures(0); + INDArray topRSentFeat = rightFeatures.getRow(0).get(NDArrayIndex.interval(0, maxL - shortL)); + INDArray midRSentFeat = rightFeatures.getRow(1).get(NDArrayIndex.interval(0, shortL)); + INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0,sent2L)); + //expected pair + combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat,topRSentFeat)); + combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat,midRSentFeat)); + combinedFeat.getRow(2).get(NDArrayIndex.interval(0,sent1L+sent2L)).addi(Nd4j.hstack(bottomLSentFeat,bottomRSentFeat)); + + assertEquals(maxL, pairMDS.getFeatures(0).shape()[1]); + assertArrayEquals(combinedFeat.shape(), pairMDS.getFeatures(0).shape()); + assertEquals(combinedFeat, pairMDS.getFeatures(0)); + + //CHECK SEGMENT ID + INDArray combinedFetSeg = Nd4j.create(DataType.INT, mbS, maxL); + combinedFetSeg.get(NDArrayIndex.point(0), NDArrayIndex.interval(shortL, maxL)).addi(1); + combinedFetSeg.get(NDArrayIndex.point(1), NDArrayIndex.interval(maxL - shortL, maxL)).addi(1); + combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L+sent2L)).addi(1); + assertArrayEquals(combinedFetSeg.shape(), pairMDS.getFeatures(1).shape()); + assertEquals(maxL, combinedFetSeg.shape()[1]); + assertEquals(combinedFetSeg, pairMDS.getFeatures(1)); + } + } + + @Test + public void testSentencePairFeaturizer() throws IOException { + String shortSent = "I saw a girl with a telescope."; + String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + List> listSentencePair = new ArrayList<>(); + listSentencePair.add(new Pair<>(shortSent, longSent)); + listSentencePair.add(new Pair<>(longSent, shortSent)); + BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + BertIterator b = BertIterator.builder() + .tokenizer(t) + .minibatchSize(2) + .padMinibatches(true) + .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) + .vocabMap(t.getVocab()) + .task(BertIterator.Task.SEQ_CLASSIFICATION) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128) + .sentencePairProvider(new TestSentencePairProvider()) + .prependToken("[CLS]") + .appendToken("[SEP]") + .build(); + MultiDataSet mds = b.next(); + INDArray[] featuresArr = mds.getFeatures(); + INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays(); + + Pair p = b.featurizeSentencePairs(listSentencePair); + assertEquals(p.getFirst().length, 2); + assertEquals(featuresArr[0], p.getFirst()[0]); + assertEquals(featuresArr[1], p.getFirst()[1]); + //assertEquals(p.getSecond().length, 2); + assertEquals(featuresMaskArr[0], p.getSecond()[0]); + //assertEquals(featuresMaskArr[1], p.getSecond()[1]); + } + + /** + * Returns three multidatasets from bert iterator based on given max lengths and whether to prepend/append + * Idea is the sentence pair dataset can be constructed from the single sentence datasets + * First one is constructed from a sentence pair "I saw a girl with a telescope." & "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum" + * Second one is constructed from the left of the sentence pair i.e "I saw a girl with a telescope." + * Third one is constructed from the right of the sentence pair i.e "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum" + */ + private Triple generateMultiDataSets(Triple maxLengths, boolean prependAppend) throws IOException { + BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + int maxforPair = maxLengths.getFirst(); + int maxPartOne = maxLengths.getSecond(); + int maxPartTwo = maxLengths.getThird(); + BertIterator.Builder commonBuilder; + commonBuilder = BertIterator.builder() + .tokenizer(t) + .minibatchSize(1) + .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) + .vocabMap(t.getVocab()) + .task(BertIterator.Task.SEQ_CLASSIFICATION); + BertIterator shortLongPairFirstIter = commonBuilder + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxforPair + 3 : maxforPair) + .sentencePairProvider(new TestSentencePairProvider()) + .prependToken(prependAppend ? "[CLS]" : null) + .appendToken(prependAppend ? "[SEP]" : null) + .build(); + BertIterator shortFirstIter = commonBuilder + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartOne + 2 : maxPartOne) + .sentenceProvider(new TestSentenceProvider()) + .prependToken(prependAppend ? "[CLS]" : null) + .appendToken(prependAppend ? "[SEP]" : null) + .build(); + BertIterator longFirstIter = commonBuilder + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartTwo + 1 : maxPartTwo) + .sentenceProvider(new TestSentenceProvider(true)) + .prependToken(null) + .appendToken(prependAppend ? "[SEP]" : null) + .build(); + return new Triple<>(shortLongPairFirstIter.next(), shortFirstIter.next(), longFirstIter.next()); + } + private static class TestSentenceProvider implements LabeledSentenceProvider { private int pos = 0; + private boolean invert; + + private TestSentenceProvider() { + this.invert = false; + } + + private TestSentenceProvider(boolean invert) { + this.invert = invert; + } @Override public boolean hasNext() { @@ -317,10 +564,20 @@ public class TestBertIterator extends BaseDL4JTest { @Override public Pair nextSentence() { Preconditions.checkState(hasNext()); - if(pos++ == 0){ - return new Pair<>("I saw a girl with a telescope.", "positive"); - } else { + if (pos == 0) { + pos++; + if (!invert) return new Pair<>("I saw a girl with a telescope.", "positive"); return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative"); + } else { + if (pos == 1) { + pos++; + if (!invert) return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative"); + return new Pair<>("I saw a girl with a telescope.", "positive"); + } + pos++; + if (!invert) + return new Pair<>("Goodnight noises everywhere", "positive"); + return new Pair<>("Goodnight moon", "positive"); } } @@ -331,8 +588,54 @@ public class TestBertIterator extends BaseDL4JTest { @Override public int totalNumSentences() { + return 3; + } + + @Override + public List allLabels() { + return Arrays.asList("positive", "negative"); + } + + @Override + public int numLabelClasses() { return 2; } + } + + private static class TestSentencePairProvider implements LabeledPairSentenceProvider { + + private int pos = 0; + + @Override + public boolean hasNext() { + return pos < totalNumSentences(); + } + + @Override + public Triple nextSentencePair() { + Preconditions.checkState(hasNext()); + if (pos == 0) { + pos++; + return new Triple<>("I saw a girl with a telescope.", "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "positive"); + } else { + if (pos == 1) { + pos++; + return new Triple<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "I saw a girl with a telescope.", "negative"); + } + pos++; + return new Triple<>("Goodnight noises everywhere", "Goodnight moon", "positive"); + } + } + + @Override + public void reset() { + pos = 0; + } + + @Override + public int totalNumSentences() { + return 3; + } @Override public List allLabels() {