diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java
index b5ca6c91a..40c43113c 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java
@@ -34,6 +34,7 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.linalg.primitives.Triple;
import java.util.ArrayList;
import java.util.Arrays;
@@ -85,10 +86,20 @@ import java.util.Map;
*
* {@code
* BertIterator b;
+ * Pair<INDArray[], INDArray[]> featuresAndMask;
+ * INDArray[] features;
+ * INDArray[] featureMasks;
+ *
+ * //With sentences
* List forInference;
- * Pair featuresAndMask = b.featurizeSentences(forInference);
- * INDArray[] features = featuresAndMask.getFirst();
- * INDArray[] featureMasks = featuresAndMask.getSecond();
+ * featuresAndMask = b.featurizeSentences(forInference);
+ *
+ * //OR with sentence pairs
+ * List<Pair<String, String>> forInferencePair;
+ * featuresAndMask = b.featurizeSentencePairs(forInferencePair);
+ *
+ * features = featuresAndMask.getFirst();
+ * featureMasks = featuresAndMask.getSecond();
* }
*
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.
@@ -135,6 +146,7 @@ public class BertIterator implements MultiDataSetIterator {
@Setter
protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null;
+ protected LabeledPairSentenceProvider sentencePairProvider = null;
protected LengthHandling lengthHandling;
protected FeatureArrays featureArrays;
protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
@@ -142,6 +154,7 @@ public class BertIterator implements MultiDataSetIterator {
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
protected String maskToken;
protected String prependToken;
+ protected String appendToken;
protected List vocabKeysAsList;
@@ -154,6 +167,7 @@ public class BertIterator implements MultiDataSetIterator {
this.padMinibatches = b.padMinibatches;
this.preProcessor = b.preProcessor;
this.sentenceProvider = b.sentenceProvider;
+ this.sentencePairProvider = b.sentencePairProvider;
this.lengthHandling = b.lengthHandling;
this.featureArrays = b.featureArrays;
this.vocabMap = b.vocabMap;
@@ -161,11 +175,14 @@ public class BertIterator implements MultiDataSetIterator {
this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
this.maskToken = b.maskToken;
this.prependToken = b.prependToken;
+ this.appendToken = b.appendToken;
}
@Override
public boolean hasNext() {
- return sentenceProvider.hasNext();
+ if (sentenceProvider != null)
+ return sentenceProvider.hasNext();
+ return sentencePairProvider.hasNext();
}
@Override
@@ -181,29 +198,38 @@ public class BertIterator implements MultiDataSetIterator {
@Override
public MultiDataSet next(int num) {
Preconditions.checkState(hasNext(), "No next element available");
-
- List> list = new ArrayList<>(num);
+ List, String>> tokensAndLabelList;
int mbSize = 0;
+ int outLength;
+ long[] segIdOnesFrom = null;
if (sentenceProvider != null) {
+ List> list = new ArrayList<>(num);
while (sentenceProvider.hasNext() && mbSize++ < num) {
list.add(sentenceProvider.nextSentence());
}
+ SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(list);
+ tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
+ outLength = sentenceListProcessed.getMaxL();
+ } else if (sentencePairProvider != null) {
+ List> listPairs = new ArrayList<>(num);
+ while (sentencePairProvider.hasNext() && mbSize++ < num) {
+ listPairs.add(sentencePairProvider.nextSentencePair());
+ }
+ SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(listPairs);
+ tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
+ outLength = sentencePairListProcessed.getMaxL();
+ segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
} else {
//TODO - other types of iterators...
throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
}
-
- Pair, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list);
- List, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
- int outLength = outLTokenizedSentencesPair.getLeft();
-
- Pair featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength);
+ Pair featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
INDArray[] featureArray = featuresAndMaskArraysPair.getFirst();
INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond();
- Pair labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength);
+ Pair labelsAndMaskArraysPair = convertMiniBatchLabels(tokensAndLabelList, featureArray, outLength);
INDArray[] labelArray = labelsAndMaskArraysPair.getFirst();
INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond();
@@ -224,32 +250,59 @@ public class BertIterator implements MultiDataSetIterator {
public Pair featurizeSentences(List listOnlySentences) {
List> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
+ SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(sentencesWithNullLabel);
+ List, String>> tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
+ int outLength = sentenceListProcessed.getMaxL();
- Pair, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel);
- List, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
- int outLength = outLTokenizedSentencesPair.getLeft();
-
- Pair featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength);
if (preProcessor != null) {
+ Pair featureFeatureMasks = convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null);
preProcessor.preProcess(dummyMDS);
- return new Pair(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
+ return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
}
- return convertMiniBatchFeatures(tokenizedSentences, outLength);
+ return convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
}
- private Pair convertMiniBatchFeatures(List, String>> tokenizedSentences, int outLength) {
- int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
+ /**
+ * For use during inference. Will convert a given list of sentence pairs to features and feature masks as appropriate.
+ *
+ * @param listOnlySentencePairs Pairs of sentences (left and right texts, no labels) to featurize
+ * @return Pair of INDArrays[], first element is feature arrays and the second is the masks array
+ */
+ public Pair featurizeSentencePairs(List> listOnlySentencePairs) {
+ Preconditions.checkState(sentencePairProvider != null, "The featurizeSentencePairs method is meant for inference with sentence pairs. Use only when the sentence pair provider is set (i.e not null).");
+
+ List> sentencePairsWithNullLabel = addDummyLabelForPairs(listOnlySentencePairs);
+ SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(sentencePairsWithNullLabel);
+ List, String>> tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
+ int outLength = sentencePairListProcessed.getMaxL();
+ long[] segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
+ if (preProcessor != null) {
+ Pair featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
+ MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featuresAndMaskArraysPair.getFirst(), null, featuresAndMaskArraysPair.getSecond(), null);
+ preProcessor.preProcess(dummyMDS);
+ return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
+ }
+ return convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
+ }
+
+ private Pair convertMiniBatchFeatures(List, String>> tokensAndLabelList, int outLength, long[] segIdOnesFrom) {
+ int mbPadded = padMinibatches ? minibatchSize : tokensAndLabelList.size();
int[][] outIdxs = new int[mbPadded][outLength];
int[][] outMask = new int[mbPadded][outLength];
- for (int i = 0; i < tokenizedSentences.size(); i++) {
- Pair, String> p = tokenizedSentences.get(i);
+ int[][] outSegmentId = null;
+ if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID)
+ outSegmentId = new int[mbPadded][outLength];
+ for (int i = 0; i < tokensAndLabelList.size(); i++) {
+ Pair, String> p = tokensAndLabelList.get(i);
List t = p.getFirst();
for (int j = 0; j < outLength && j < t.size(); j++) {
Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j));
int idx = vocabMap.get(t.get(j));
outIdxs[i][j] = idx;
outMask[i][j] = 1;
+ if (segIdOnesFrom != null && j >= segIdOnesFrom[i])
+ outSegmentId[i][j] = 1;
}
}
@@ -260,8 +313,7 @@ public class BertIterator implements MultiDataSetIterator {
INDArray[] f;
INDArray[] fm;
if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) {
- //For now: always segment index 0 (only single s sequence input supported)
- outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength);
+ outSegmentIdArr = Nd4j.createFromArray(outSegmentId);
f = new INDArray[]{outIdxsArr, outSegmentIdArr};
fm = new INDArray[]{outMaskArr, null};
} else {
@@ -271,16 +323,15 @@ public class BertIterator implements MultiDataSetIterator {
return new Pair<>(f, fm);
}
- private Pair, String>>> tokenizeMiniBatch(List> list) {
+ private SentenceListProcessed tokenizeMiniBatch(List> list) {
//Get and tokenize the sentences for this minibatch
- List, String>> tokenizedSentences = new ArrayList<>(list.size());
+ SentenceListProcessed sentenceListProcessed = new SentenceListProcessed(list.size());
int longestSeq = -1;
for (Pair p : list) {
List tokens = tokenizeSentence(p.getFirst());
- tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
+ sentenceListProcessed.addProcessedToList(new Pair<>(tokens, p.getSecond()));
longestSeq = Math.max(longestSeq, tokens.size());
}
-
//Determine output array length...
int outLength;
switch (lengthHandling) {
@@ -296,7 +347,52 @@ public class BertIterator implements MultiDataSetIterator {
default:
throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
}
- return new Pair<>(outLength, tokenizedSentences);
+ sentenceListProcessed.setMaxL(outLength);
+ return sentenceListProcessed;
+ }
+
+ private SentencePairListProcessed tokenizePairsMiniBatch(List> listPairs) {
+ SentencePairListProcessed sentencePairListProcessed = new SentencePairListProcessed(listPairs.size());
+ for (Triple t : listPairs) {
+ List tokensL = tokenizeSentence(t.getFirst(), true);
+ List tokensR = tokenizeSentence(t.getSecond(), true);
+ List tokens = new ArrayList<>(maxTokens);
+ int maxLength = maxTokens;
+ if (prependToken != null)
+ maxLength--;
+ if (appendToken != null)
+ maxLength -= 2;
+ if (tokensL.size() + tokensR.size() > maxLength) {
+ boolean shortOnL = tokensL.size() < tokensR.size();
+ int shortSize = Math.min(tokensL.size(), tokensR.size());
+ if (shortSize > maxLength / 2) {
+ //both lists need to be sliced
+ tokensL.subList(maxLength / 2, tokensL.size()).clear(); //if maxLength is odd, L keeps floor(maxLength/2) tokens and the extra slot goes to the R side, to match the TF reference implementation
+ tokensR.subList(maxLength - maxLength / 2, tokensR.size()).clear();
+ } else {
+ //slice longer list
+ if (shortOnL) {
+ //longer on R - slice R
+ tokensR.subList(maxLength - tokensL.size(), tokensR.size()).clear();
+ } else {
+ //longer on L - slice L
+ tokensL.subList(maxLength - tokensR.size(), tokensL.size()).clear();
+ }
+ }
+ }
+ if (prependToken != null)
+ tokens.add(prependToken);
+ tokens.addAll(tokensL);
+ if (appendToken != null)
+ tokens.add(appendToken);
+ int segIdOnesFrom = tokens.size();
+ tokens.addAll(tokensR);
+ if (appendToken != null)
+ tokens.add(appendToken);
+ sentencePairListProcessed.addProcessedToList(segIdOnesFrom, new Pair<>(tokens, t.getThird()));
+ }
+ sentencePairListProcessed.setMaxL(maxTokens);
+ return sentencePairListProcessed;
}
private Pair convertMiniBatchLabels(List, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
@@ -316,6 +412,14 @@ public class BertIterator implements MultiDataSetIterator {
classLabels[i] = labels.indexOf(lbl);
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
}
+ } else if (sentencePairProvider != null) {
+ numClasses = sentencePairProvider.numLabelClasses();
+ List labels = sentencePairProvider.allLabels();
+ for (int i = 0; i < mbSize; i++) {
+ String lbl = tokenizedSentences.get(i).getRight();
+ classLabels[i] = labels.indexOf(lbl);
+ Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
+ }
} else {
throw new RuntimeException();
}
@@ -392,16 +496,22 @@ public class BertIterator implements MultiDataSetIterator {
}
private List tokenizeSentence(String sentence) {
+ return tokenizeSentence(sentence, false);
+ }
+
+ private List tokenizeSentence(String sentence, boolean ignorePrependAppend) {
Tokenizer t = tokenizerFactory.create(sentence);
List tokens = new ArrayList<>();
- if (prependToken != null)
+ if (prependToken != null && !ignorePrependAppend)
tokens.add(prependToken);
while (t.hasMoreTokens()) {
String token = t.nextToken();
tokens.add(token);
}
+ if (appendToken != null && !ignorePrependAppend)
+ tokens.add(appendToken);
return tokens;
}
@@ -414,6 +524,13 @@ public class BertIterator implements MultiDataSetIterator {
return list;
}
+ private List> addDummyLabelForPairs(List> listOnlySentencePairs) {
+ List> list = new ArrayList<>(listOnlySentencePairs.size());
+ for (Pair p : listOnlySentencePairs) {
+ list.add(new Triple(p.getFirst(), p.getSecond(), null));
+ }
+ return list;
+ }
@Override
public boolean resetSupported() {
@@ -446,12 +563,14 @@ public class BertIterator implements MultiDataSetIterator {
protected boolean padMinibatches = false;
protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null;
+ protected LabeledPairSentenceProvider sentencePairProvider = null;
protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
protected BertSequenceMasker masker = new BertMaskedLMMasker();
protected UnsupervisedLabelFormat unsupervisedLabelFormat;
protected String maskToken;
protected String prependToken;
+ protected String appendToken;
/**
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
@@ -519,14 +638,21 @@ public class BertIterator implements MultiDataSetIterator {
}
/**
- * Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised
- * use case, the labels will be ignored.
+ * Specify the source of the data for classification.
*/
public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) {
this.sentenceProvider = sentenceProvider;
return this;
}
+ /**
+ * Specify the source of the data for classification on sentence pairs.
+ */
+ public Builder sentencePairProvider(LabeledPairSentenceProvider sentencePairProvider) {
+ this.sentencePairProvider = sentencePairProvider;
+ return this;
+ }
+
/**
* Specify what arrays should be returned. See {@link BertIterator} for more details.
*/
@@ -591,6 +717,19 @@ public class BertIterator implements MultiDataSetIterator {
return this;
}
+ /**
+ * Append the specified token to the sequences, when doing training on sentence pairs.
+ * Generally "[SEP]" is used
+ * No token is appended by default.
+ *
+ * @param appendToken Token at end of each sentence for pairs of sentences (null: no token will be appended)
+ * @return this Builder, for method chaining
+ */
+ public Builder appendToken(String appendToken) {
+ this.appendToken = appendToken;
+ return this;
+ }
+
public BertIterator build() {
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
@@ -598,9 +737,69 @@ public class BertIterator implements MultiDataSetIterator {
Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method");
Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via masker(BertSequenceMasker) method");
Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\" must be specified");
-
+ if (sentencePairProvider != null) {
+ Preconditions.checkState(task == Task.SEQ_CLASSIFICATION, "Currently only supervised sequence classification is set up with sentence pairs. \".task(BertIterator.Task.SEQ_CLASSIFICATION)\" is required with a sentence pair provider");
+ Preconditions.checkState(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID, "Currently only supervised sequence classification is set up with sentence pairs. \".featureArrays(FeatureArrays.INDICES_MASK_SEGMENTID)\" is required with a sentence pair provider");
+ Preconditions.checkState(lengthHandling == LengthHandling.FIXED_LENGTH, "Currently only fixed length is supported for sentence pairs. \".lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxLength)\" is required with a sentence pair provider");
+ Preconditions.checkState(sentencePairProvider != null, "Provide either a sentence provider or a sentence pair provider. Both cannot be non null");
+ }
+ if (appendToken != null) {
+ Preconditions.checkState(sentencePairProvider != null, "Tokens are only appended with sentence pairs. Sentence pair provider is not set. Set sentence pair provider.");
+ }
return new BertIterator(this);
}
}
+ private static class SentencePairListProcessed {
+ private int listLength = 0;
+
+ @Getter
+ private long[] segIdOnesFrom;
+ private int cursor = 0;
+ private SentenceListProcessed sentenceListProcessed;
+
+ private SentencePairListProcessed(int listLength) {
+ this.listLength = listLength;
+ segIdOnesFrom = new long[listLength];
+ sentenceListProcessed = new SentenceListProcessed(listLength);
+ }
+
+ private void addProcessedToList(long segIdIdx, Pair, String> tokenizedSentencePairAndLabel) {
+ segIdOnesFrom[cursor] = segIdIdx;
+ sentenceListProcessed.addProcessedToList(tokenizedSentencePairAndLabel);
+ cursor++;
+ }
+
+ private void setMaxL(int maxL) {
+ sentenceListProcessed.setMaxL(maxL);
+ }
+
+ private int getMaxL() {
+ return sentenceListProcessed.getMaxL();
+ }
+
+ private List, String>> getTokensAndLabelList() {
+ return sentenceListProcessed.getTokensAndLabelList();
+ }
+ }
+
+ private static class SentenceListProcessed {
+ private int listLength;
+
+ @Getter
+ @Setter
+ private int maxL;
+
+ @Getter
+ private List, String>> tokensAndLabelList;
+
+ private SentenceListProcessed(int listLength) {
+ this.listLength = listLength;
+ tokensAndLabelList = new ArrayList<>(listLength);
+ }
+
+ private void addProcessedToList(Pair, String> tokenizedSentenceAndLabel) {
+ tokensAndLabelList.add(tokenizedSentenceAndLabel);
+ }
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/LabeledPairSentenceProvider.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/LabeledPairSentenceProvider.java
new file mode 100644
index 000000000..ee68477ee
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/LabeledPairSentenceProvider.java
@@ -0,0 +1,60 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.iterator;
+
+import org.nd4j.linalg.primitives.Triple;
+
+import java.util.List;
+
+/**
+ * LabeledPairSentenceProvider: a simple iterator interface over a pair of sentences/documents that have a label.
+ */
+public interface LabeledPairSentenceProvider {
+
+ /**
+ * Are there more sentences/documents available?
+ */
+ boolean hasNext();
+
+ /**
+ * @return Triple: two sentence/document texts and label
+ */
+ Triple nextSentencePair();
+
+ /**
+ * Reset the iterator - including shuffling the order, if necessary/appropriate
+ */
+ void reset();
+
+ /**
+ * Return the total number of sentences, or -1 if not available
+ */
+ int totalNumSentences();
+
+ /**
+ * Return the list of labels - this also defines the class/integer label assignment order
+ */
+ List allLabels();
+
+ /**
+ * Equivalent to allLabels().size()
+ */
+ int numLabelClasses();
+
+}
+
+
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledPairSentenceProvider.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledPairSentenceProvider.java
new file mode 100644
index 000000000..c3c752bed
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledPairSentenceProvider.java
@@ -0,0 +1,135 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.iterator.provider;
+
+import lombok.NonNull;
+import org.deeplearning4j.iterator.LabeledPairSentenceProvider;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.primitives.Triple;
+import org.nd4j.linalg.util.MathUtils;
+
+import java.util.*;
+
+/**
+ * Iterate over pairs of sentences/documents,
+ * where the sentences and labels are provided in lists.
+ *
+ */
+public class CollectionLabeledPairSentenceProvider implements LabeledPairSentenceProvider {
+
+ private final List sentenceL;
+ private final List sentenceR;
+ private final List labels;
+ private final Random rng;
+ private final int[] order;
+ private final List allLabels;
+
+ private int cursor = 0;
+
+ /**
+ * Two lists of sentences to iterate over, together with a third list containing their labels
+ * Sentences in the same position in the first two lists are considered a pair
+ * @param sentenceL
+ * @param sentenceR
+ * @param labelsForSentences
+ */
+ public CollectionLabeledPairSentenceProvider(@NonNull List sentenceL, @NonNull List sentenceR,
+ @NonNull List labelsForSentences) {
+ this(sentenceL, sentenceR, labelsForSentences, new Random());
+ }
+
+ /**
+ * Lists containing sentences to iterate over with a third for labels
+ * Sentences in the same position in the first two lists are considered a pair
+ * @param sentenceL
+ * @param sentenceR
+ * @param labelsForSentences
+ * @param rng If null, list order is not shuffled
+ */
+ public CollectionLabeledPairSentenceProvider(@NonNull List sentenceL, List sentenceR, @NonNull List labelsForSentences,
+ Random rng) {
+ if (sentenceR.size() != sentenceL.size()) {
+ throw new IllegalArgumentException("Sentence lists must be same size (first list size: "
+ + sentenceL.size() + ", second list size: " + sentenceR.size() + ")");
+ }
+ if (sentenceR.size() != labelsForSentences.size()) {
+ throw new IllegalArgumentException("Sentence pairs and labels must be same size (sentence pair size: "
+ + sentenceR.size() + ", labels size: " + labelsForSentences.size() + ")");
+ }
+
+ this.sentenceL = sentenceL;
+ this.sentenceR = sentenceR;
+ this.labels = labelsForSentences;
+ this.rng = rng;
+ if (rng == null) {
+ order = null;
+ } else {
+ order = new int[sentenceR.size()];
+ for (int i = 0; i < sentenceR.size(); i++) {
+ order[i] = i;
+ }
+
+ MathUtils.shuffleArray(order, rng);
+ }
+
+ //Collect set of unique labels for all sentences
+ Set uniqueLabels = new HashSet<>(labelsForSentences);
+ allLabels = new ArrayList<>(uniqueLabels);
+ Collections.sort(allLabels);
+ }
+
+ @Override
+ public boolean hasNext() {
+ return cursor < sentenceR.size();
+ }
+
+ @Override
+ public Triple nextSentencePair() {
+ Preconditions.checkState(hasNext(),"No next element available");
+ int idx;
+ if (rng == null) {
+ idx = cursor++;
+ } else {
+ idx = order[cursor++];
+ }
+ return new Triple<>(sentenceL.get(idx), sentenceR.get(idx), labels.get(idx));
+ }
+
+ @Override
+ public void reset() {
+ cursor = 0;
+ if (rng != null) {
+ MathUtils.shuffleArray(order, rng);
+ }
+ }
+
+ @Override
+ public int totalNumSentences() {
+ return sentenceR.size();
+ }
+
+ @Override
+ public List allLabels() {
+ return allLabels;
+ }
+
+ @Override
+ public int numLabelClasses() {
+ return allLabels.size();
+ }
+}
+
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledSentenceProvider.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledSentenceProvider.java
index 3dbaa7db8..e6d65b48b 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledSentenceProvider.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledSentenceProvider.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.iterator.provider;
import lombok.NonNull;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
+import org.nd4j.base.Preconditions;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.MathUtils;
@@ -40,15 +41,15 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide
private int cursor = 0;
public CollectionLabeledSentenceProvider(@NonNull List sentences,
- @NonNull List labelsForSentences) {
+ @NonNull List labelsForSentences) {
this(sentences, labelsForSentences, new Random());
}
public CollectionLabeledSentenceProvider(@NonNull List sentences, @NonNull List labelsForSentences,
- Random rng) {
+ Random rng) {
if (sentences.size() != labelsForSentences.size()) {
throw new IllegalArgumentException("Sentences and labels must be same size (sentences size: "
- + sentences.size() + ", labels size: " + labelsForSentences.size() + ")");
+ + sentences.size() + ", labels size: " + labelsForSentences.size() + ")");
}
this.sentences = sentences;
@@ -66,10 +67,7 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide
}
//Collect set of unique labels for all sentences
- Set uniqueLabels = new HashSet<>();
- for (String s : labelsForSentences) {
- uniqueLabels.add(s);
- }
+ Set uniqueLabels = new HashSet<>(labelsForSentences);
allLabels = new ArrayList<>(uniqueLabels);
Collections.sort(allLabels);
}
@@ -81,6 +79,7 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide
@Override
public Pair nextSentence() {
+ Preconditions.checkState(hasNext(), "No next element available");
int idx;
if (rng == null) {
idx = cursor++;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java
index d4be5e352..a6716ba40 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java
@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -27,6 +28,7 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.linalg.primitives.Triple;
import org.nd4j.resources.Resources;
import java.io.File;
@@ -43,7 +45,8 @@ public class TestBertIterator extends BaseDL4JTest {
private File pathToVocab = Resources.asFile("other/vocab.txt");
private static Charset c = StandardCharsets.UTF_8;
- public TestBertIterator() throws IOException{ }
+ public TestBertIterator() throws IOException {
+ }
@Test(timeout = 20000L)
public void testBertSequenceClassification() throws Exception {
@@ -74,8 +77,8 @@ public class TestBertIterator extends BaseDL4JTest {
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
List tokens = t.create(toTokenize1).getTokens();
- Map m = t.getVocab();
- for( int i=0; i m = t.getVocab();
+ for (int i = 0; i < tokens.size(); i++) {
int idx = m.get(tokens.get(i));
expEx0.putScalar(0, i, idx);
expM0.putScalar(0, i, 1);
@@ -84,9 +87,9 @@ public class TestBertIterator extends BaseDL4JTest {
INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
List tokens2 = t.create(toTokenize2).getTokens();
- for( int i=0; i tokens = t.create(toTokenize1).getTokens();
- Map m = t.getVocab();
- for( int i=0; i m = t.getVocab();
+ for (int i = 0; i < tokens.size(); i++) {
int idx = m.get(tokens.get(i));
expEx0.putScalar(0, i, idx);
expM0.putScalar(0, i, 1);
@@ -178,9 +184,10 @@ public class TestBertIterator extends BaseDL4JTest {
INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
List tokens2 = t.create(toTokenize2).getTokens();
- for( int i=0; i forInference = new ArrayList<>();
forInference.add(toTokenize1);
forInference.add(toTokenize2);
+ forInference.add(toTokenize3);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
List tokens = t.create(toTokenize1).getTokens();
- Map m = t.getVocab();
- for( int i=0; i m = t.getVocab();
+ for (int i = 0; i < tokens.size(); i++) {
int idx = m.get(tokens.get(i));
expEx0.putScalar(0, i, idx);
expM0.putScalar(0, i, 1);
@@ -253,9 +262,9 @@ public class TestBertIterator extends BaseDL4JTest {
INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
List tokens2 = t.create(toTokenize2).getTokens();
- for( int i=0; i tokens3 = t.create(toTokenize3).getTokens();
+ for (int i = 0; i < tokens3.size(); i++) {
+ String token = tokens3.get(i);
+ if (!m.containsKey(token)) {
+ throw new IllegalStateException("Unknown token: \"" + token + "\"");
+ }
+ int idx = m.get(token);
+ expEx3.putScalar(0, i, idx);
+ expM3.putScalar(0, i, 1);
+ }
- INDArray expF = Nd4j.vstack(expEx0, expEx1, zeros);
- INDArray expM = Nd4j.vstack(expM0, expM1, zeros);
- INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {0, 0}, {0, 0}});
+ INDArray zeros = Nd4j.create(DataType.INT, 1, 16);
+ INDArray expF = Nd4j.vstack(expEx0, expEx1, expEx3, zeros);
+ INDArray expM = Nd4j.vstack(expM0, expM1, expM3, zeros);
+ INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {1, 0}, {0, 0}});
INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1);
expLM.putScalar(0, 0, 1);
expLM.putScalar(1, 0, 1);
+ expLM.putScalar(2, 0, 1);
//--------------------------------------------------------------
@@ -305,9 +327,234 @@ public class TestBertIterator extends BaseDL4JTest {
assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
}
+ /**
+ * Checks that featurizing a sentence pair is equivalent to hstack-ing the two
+ * individually featurized sentences (token indices, segment ids, and feature
+ * masks), across three truncation regimes of the pair's maximum length:
+ * exact fit, longer than needed (padding), and shorter than either sentence.
+ */
+ @Test
+ public void testSentencePairsSingle() throws IOException {
+ String shortSent = "I saw a girl with a telescope.";
+ String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+ boolean prependAppend;
+ BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+ int shortL = t.create(shortSent).countTokens();
+ int longL = t.create(longSent).countTokens();
+
+ Triple<MultiDataSet, MultiDataSet, MultiDataSet> multiDataSetTriple;
+ MultiDataSet shortLongPair, shortSentence, longSentence;
+
+ // check for pair max length exactly equal to sum of lengths - pop neither, no padding
+ // should be the same as hstack with segment ids 1 for second sentence
+ prependAppend = true;
+ multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend);
+ shortLongPair = multiDataSetTriple.getFirst();
+ shortSentence = multiDataSetTriple.getSecond();
+ longSentence = multiDataSetTriple.getThird();
+ assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
+ longSentence.getFeatures(1).addi(1); // second sentence of a pair gets segment id 1
+ assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
+ assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
+
+ //check for pair max length greater than sum of lengths - pop neither, with padding
+ // features should be the same as hstack of shorter and longer padded with prepend/append
+ // segment id should be 1 only in the longer sentence, and only for the unpadded part of its length
+ prependAppend = true;
+ multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend);
+ shortLongPair = multiDataSetTriple.getFirst();
+ shortSentence = multiDataSetTriple.getSecond();
+ longSentence = multiDataSetTriple.getThird();
+ assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
+ longSentence.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part
+ assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
+ assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
+
+ //check for pair max length less than shorter sentence - pop both
+ //should be the same as hstack with segment ids 1 for second sentence if no prepend/append
+ int maxL = shortL - 2;
+ prependAppend = false;
+ multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend);
+ shortLongPair = multiDataSetTriple.getFirst();
+ shortSentence = multiDataSetTriple.getSecond();
+ longSentence = multiDataSetTriple.getThird();
+ assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
+ longSentence.getFeatures(1).addi(1);
+ assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
+ assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
+ }
+
+ /**
+ * Verifies pair truncation when only the longer sentence needs popping:
+ * the pair features and segment ids must equal the row-wise concatenation
+ * of independently-featurized left/right sentences truncated to fit maxL.
+ */
+ @Test
+ public void testSentencePairsUnequalLengths() throws IOException {
+ //check for pop only longer (i.e between longer and longer + shorter), first row pop from second sentence, next row pop from first sentence, nothing to pop in the third row
+ //should be identical to hstack if there is no append, prepend
+ //minibatch size is 4; only 3 pairs are provided, so the last row is padding
+ int mbS = 4;
+ String shortSent = "I saw a girl with a telescope.";
+ String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+ String sent1 = "Goodnight noises everywhere"; //shorter than shortSent - no popping
+ String sent2 = "Goodnight moon"; //shorter than shortSent - no popping
+ BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+ int shortL = t.create(shortSent).countTokens();
+ int longL = t.create(longSent).countTokens();
+ int sent1L = t.create(sent1).countTokens();
+ int sent2L = t.create(sent2).countTokens();
+ //won't check 2*shortL + 1 because this will always pop on the left
+ for (int maxL = longL + shortL - 1; maxL > 2 * shortL; maxL--) {
+ MultiDataSet leftMDS = BertIterator.builder()
+ .tokenizer(t)
+ .minibatchSize(mbS)
+ .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
+ .vocabMap(t.getVocab())
+ .task(BertIterator.Task.SEQ_CLASSIFICATION)
+ .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
+ .sentenceProvider(new TestSentenceProvider())
+ .padMinibatches(true)
+ .build().next();
+
+ MultiDataSet rightMDS = BertIterator.builder()
+ .tokenizer(t)
+ .minibatchSize(mbS)
+ .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
+ .vocabMap(t.getVocab())
+ .task(BertIterator.Task.SEQ_CLASSIFICATION)
+ .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
+ .sentenceProvider(new TestSentenceProvider(true))
+ .padMinibatches(true)
+ .build().next();
+
+ MultiDataSet pairMDS = BertIterator.builder()
+ .tokenizer(t)
+ .minibatchSize(mbS)
+ .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
+ .vocabMap(t.getVocab())
+ .task(BertIterator.Task.SEQ_CLASSIFICATION)
+ .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) //pair truncation length under test
+ .sentencePairProvider(new TestSentencePairProvider())
+ .padMinibatches(true)
+ .build().next();
+
+ //Left sentences here are {{shortSent},
+ // {longSent},
+ // {Sent1}}
+ //Right sentences here are {{longSent},
+ // {shortSent},
+ // {Sent2}}
+ //The sentence pairs here are {{shortSent,longSent},
+ // {longSent,shortSent}
+ // {Sent1, Sent2}}
+
+ //CHECK FEATURES
+ INDArray combinedFeat = Nd4j.create(DataType.INT,mbS,maxL);
+ //left side
+ INDArray leftFeatures = leftMDS.getFeatures(0);
+ INDArray topLSentFeat = leftFeatures.getRow(0).get(NDArrayIndex.interval(0, shortL));
+ INDArray midLSentFeat = leftFeatures.getRow(1).get(NDArrayIndex.interval(0, maxL - shortL));
+ INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0,sent1L));
+ //right side
+ INDArray rightFeatures = rightMDS.getFeatures(0);
+ INDArray topRSentFeat = rightFeatures.getRow(0).get(NDArrayIndex.interval(0, maxL - shortL));
+ INDArray midRSentFeat = rightFeatures.getRow(1).get(NDArrayIndex.interval(0, shortL));
+ INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0,sent2L));
+ //expected pair
+ combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat,topRSentFeat));
+ combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat,midRSentFeat));
+ combinedFeat.getRow(2).get(NDArrayIndex.interval(0,sent1L+sent2L)).addi(Nd4j.hstack(bottomLSentFeat,bottomRSentFeat));
+
+ assertEquals(maxL, pairMDS.getFeatures(0).shape()[1]);
+ assertArrayEquals(combinedFeat.shape(), pairMDS.getFeatures(0).shape());
+ assertEquals(combinedFeat, pairMDS.getFeatures(0));
+
+ //CHECK SEGMENT ID
+ INDArray combinedFetSeg = Nd4j.create(DataType.INT, mbS, maxL);
+ combinedFetSeg.get(NDArrayIndex.point(0), NDArrayIndex.interval(shortL, maxL)).addi(1);
+ combinedFetSeg.get(NDArrayIndex.point(1), NDArrayIndex.interval(maxL - shortL, maxL)).addi(1);
+ combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L+sent2L)).addi(1);
+ assertArrayEquals(combinedFetSeg.shape(), pairMDS.getFeatures(1).shape());
+ assertEquals(maxL, combinedFetSeg.shape()[1]);
+ assertEquals(combinedFetSeg, pairMDS.getFeatures(1));
+ }
+ }
+
+ /**
+ * Checks that {@code featurizeSentencePairs} produces the same feature and
+ * mask arrays as pulling the same pairs through the iterator's next().
+ */
+ @Test
+ public void testSentencePairFeaturizer() throws IOException {
+ String shortSent = "I saw a girl with a telescope.";
+ String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+ List<Pair<String, String>> listSentencePair = new ArrayList<>();
+ listSentencePair.add(new Pair<>(shortSent, longSent));
+ listSentencePair.add(new Pair<>(longSent, shortSent));
+ BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+ BertIterator b = BertIterator.builder()
+ .tokenizer(t)
+ .minibatchSize(2)
+ .padMinibatches(true)
+ .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
+ .vocabMap(t.getVocab())
+ .task(BertIterator.Task.SEQ_CLASSIFICATION)
+ .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)
+ .sentencePairProvider(new TestSentencePairProvider())
+ .prependToken("[CLS]")
+ .appendToken("[SEP]")
+ .build();
+ MultiDataSet mds = b.next();
+ INDArray[] featuresArr = mds.getFeatures();
+ INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays();
+
+ Pair<INDArray[], INDArray[]> p = b.featurizeSentencePairs(listSentencePair);
+ assertEquals(p.getFirst().length, 2);
+ assertEquals(featuresArr[0], p.getFirst()[0]);
+ assertEquals(featuresArr[1], p.getFirst()[1]);
+ // NOTE(review): the two assertions below are disabled — confirm whether
+ // featurizeSentencePairs is expected to return a second mask array here.
+ //assertEquals(p.getSecond().length, 2);
+ assertEquals(featuresMaskArr[0], p.getSecond()[0]);
+ //assertEquals(featuresMaskArr[1], p.getSecond()[1]);
+ }
+
+ /**
+ * Returns three MultiDataSets from BertIterator, based on the given max lengths
+ * and whether to prepend/append special tokens. The idea is that the sentence-pair
+ * dataset can be reconstructed from the two single-sentence datasets.
+ * First: built from the sentence pair "I saw a girl with a telescope." and
+ * "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum".
+ * Second: built from the left sentence alone ("I saw a girl with a telescope.").
+ * Third: built from the right sentence alone, with no prepend token, mimicking
+ * the second half of a pair.
+ *
+ * @param maxLengths    max token lengths: (pair, first sentence, second sentence)
+ * @param prependAppend whether to add "[CLS]"/"[SEP]" special tokens
+ * NOTE(review): the same Builder instance is mutated for all three iterators;
+ * this relies on later setter calls overriding earlier state (e.g. the
+ * sentencePairProvider set for the first build) — confirm no conflict.
+ */
+ private Triple<MultiDataSet, MultiDataSet, MultiDataSet> generateMultiDataSets(Triple<Integer, Integer, Integer> maxLengths, boolean prependAppend) throws IOException {
+ BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+ int maxforPair = maxLengths.getFirst();
+ int maxPartOne = maxLengths.getSecond();
+ int maxPartTwo = maxLengths.getThird();
+ BertIterator.Builder commonBuilder;
+ commonBuilder = BertIterator.builder()
+ .tokenizer(t)
+ .minibatchSize(1)
+ .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
+ .vocabMap(t.getVocab())
+ .task(BertIterator.Task.SEQ_CLASSIFICATION);
+ BertIterator shortLongPairFirstIter = commonBuilder
+ .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxforPair + 3 : maxforPair) // +3: one [CLS] and two [SEP]
+ .sentencePairProvider(new TestSentencePairProvider())
+ .prependToken(prependAppend ? "[CLS]" : null)
+ .appendToken(prependAppend ? "[SEP]" : null)
+ .build();
+ BertIterator shortFirstIter = commonBuilder
+ .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartOne + 2 : maxPartOne) // +2: [CLS] and [SEP]
+ .sentenceProvider(new TestSentenceProvider())
+ .prependToken(prependAppend ? "[CLS]" : null)
+ .appendToken(prependAppend ? "[SEP]" : null)
+ .build();
+ BertIterator longFirstIter = commonBuilder
+ .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartTwo + 1 : maxPartTwo) // +1: trailing [SEP] only
+ .sentenceProvider(new TestSentenceProvider(true))
+ .prependToken(null)
+ .appendToken(prependAppend ? "[SEP]" : null)
+ .build();
+ return new Triple<>(shortLongPairFirstIter.next(), shortFirstIter.next(), longFirstIter.next());
+ }
+
private static class TestSentenceProvider implements LabeledSentenceProvider {
private int pos = 0;
+ // When true, serves the first two sentences in swapped order, so this
+ // provider can act as the "right side" counterpart of a pair provider.
+ private boolean invert;
+
+ private TestSentenceProvider() {
+ this.invert = false;
+ }
+
+ private TestSentenceProvider(boolean invert) {
+ this.invert = invert;
+ }
@Override
public boolean hasNext() {
@@ -317,10 +564,20 @@ public class TestBertIterator extends BaseDL4JTest {
@Override
public Pair nextSentence() {
Preconditions.checkState(hasNext());
+ // Serves 3 fixed labelled sentences; 'invert' swaps the first two so that
+ // inverted/non-inverted providers mirror the two sides of the pair provider.
- if(pos++ == 0){
- return new Pair<>("I saw a girl with a telescope.", "positive");
- } else {
+ if (pos == 0) {
+ pos++;
+ if (!invert) return new Pair<>("I saw a girl with a telescope.", "positive");
return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
+ } else {
+ if (pos == 1) {
+ pos++;
+ if (!invert) return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
+ return new Pair<>("I saw a girl with a telescope.", "positive");
+ }
+ pos++;
+ if (!invert)
+ return new Pair<>("Goodnight noises everywhere", "positive");
+ return new Pair<>("Goodnight moon", "positive");
}
}
@@ -331,8 +588,54 @@ public class TestBertIterator extends BaseDL4JTest {
@Override
public int totalNumSentences() {
+ return 3;
+ }
+
+ @Override
+ public List<String> allLabels() {
+ // Fixed binary sentiment label set used by all test sentences.
+ return Arrays.asList("positive", "negative");
+ }
+
+ @Override
+ public int numLabelClasses() {
return 2;
}
+ }
+
+ private static class TestSentencePairProvider implements LabeledPairSentenceProvider {
+
+ private int pos = 0;
+
+ @Override
+ public boolean hasNext() {
+ // More pairs remain while fewer than totalNumSentences() have been served.
+ return pos < totalNumSentences();
+ }
+
+ /**
+ * Serves 3 fixed labelled sentence pairs in order; pairs correspond to the
+ * rows produced by TestSentenceProvider (left side) and
+ * TestSentenceProvider(true) (right side).
+ */
+ @Override
+ public Triple<String, String, String> nextSentencePair() {
+ Preconditions.checkState(hasNext());
+ if (pos == 0) {
+ pos++;
+ return new Triple<>("I saw a girl with a telescope.", "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "positive");
+ } else {
+ if (pos == 1) {
+ pos++;
+ return new Triple<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "I saw a girl with a telescope.", "negative");
+ }
+ pos++;
+ return new Triple<>("Goodnight noises everywhere", "Goodnight moon", "positive");
+ }
+ }
+
+ @Override
+ public void reset() {
+ // Restart iteration from the first pair.
+ pos = 0;
+ }
+
+ @Override
+ public int totalNumSentences() {
+ // Three hard-coded sentence pairs are served by nextSentencePair().
+ return 3;
+ }
@Override
public List allLabels() {