BertIterator sentence pair support for supervised training (#108)
* bert iterator sentence pair handling Signed-off-by: eraly <susan.eraly@gmail.com> * bert iterator sentence pair handling -seg Signed-off-by: eraly <susan.eraly@gmail.com> * bert iterator sentence pair handling tests Signed-off-by: eraly <susan.eraly@gmail.com> * test with pairs long done Signed-off-by: eraly <susan.eraly@gmail.com> * more tests with bert iter sent pairs done Signed-off-by: eraly <susan.eraly@gmail.com> * fixed copyright, formatting Signed-off-by: eraly <susan.eraly@gmail.com> * bert iterator - added featurizer for sentence pair inference Signed-off-by: eraly <susan.eraly@gmail.com> * bert iterator - finished tests Signed-off-by: eraly <susan.eraly@gmail.com> * bert iterator - finished tests, polish Signed-off-by: eraly <susan.eraly@gmail.com> * collection labeled sentence provider Signed-off-by: eraly <susan.eraly@gmail.com> * lombok fix for pojo class Signed-off-by: eraly <susan.eraly@gmail.com> * java doc misc clean up Signed-off-by: eraly <susan.eraly@gmail.com> * Private access modifiers Signed-off-by: AlexDBlack <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									9cc8803b8d
								
							
						
					
					
						commit
						91de96588c
					
				@ -34,6 +34,7 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
import org.nd4j.linalg.indexing.NDArrayIndex;
 | 
			
		||||
import org.nd4j.linalg.primitives.Pair;
 | 
			
		||||
import org.nd4j.linalg.primitives.Triple;
 | 
			
		||||
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
@ -85,10 +86,20 @@ import java.util.Map;
 | 
			
		||||
 * <pre>
 | 
			
		||||
 * {@code
 | 
			
		||||
 *          BertIterator b;
 | 
			
		||||
 *          Pair<INDArray[],INDArray[]> featuresAndMask;
 | 
			
		||||
 *          INDArray[] features;
 | 
			
		||||
 *          INDArray[] featureMasks;
 | 
			
		||||
 *
 | 
			
		||||
 *          //With sentences
 | 
			
		||||
 *          List<String> forInference;
 | 
			
		||||
 *          Pair<INDArray[],INDArray[]> featuresAndMask = b.featurizeSentences(forInference);
 | 
			
		||||
 *          INDArray[] features = featuresAndMask.getFirst();
 | 
			
		||||
 *          INDArray[] featureMasks = featuresAndMask.getSecond();
 | 
			
		||||
 *          featuresAndMask = b.featurizeSentences(forInference);
 | 
			
		||||
 *
 | 
			
		||||
 *          //OR with sentence pairs
 | 
			
		||||
 *          List<Pair<String, String>> forInferencePair};
 | 
			
		||||
 *          featuresAndMask = b.featurizeSentencePairs(forInference);
 | 
			
		||||
 *
 | 
			
		||||
 *          features = featuresAndMask.getFirst();
 | 
			
		||||
 *          featureMasks = featuresAndMask.getSecond();
 | 
			
		||||
 * }
 | 
			
		||||
 * </pre>
 | 
			
		||||
 * This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br>
 | 
			
		||||
@ -135,6 +146,7 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
    @Setter
 | 
			
		||||
    protected MultiDataSetPreProcessor preProcessor;
 | 
			
		||||
    protected LabeledSentenceProvider sentenceProvider = null;
 | 
			
		||||
    protected LabeledPairSentenceProvider sentencePairProvider = null;
 | 
			
		||||
    protected LengthHandling lengthHandling;
 | 
			
		||||
    protected FeatureArrays featureArrays;
 | 
			
		||||
    protected Map<String, Integer> vocabMap;   //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
 | 
			
		||||
@ -142,6 +154,7 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
    protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
 | 
			
		||||
    protected String maskToken;
 | 
			
		||||
    protected String prependToken;
 | 
			
		||||
    protected String appendToken;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    protected List<String> vocabKeysAsList;
 | 
			
		||||
@ -154,6 +167,7 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
        this.padMinibatches = b.padMinibatches;
 | 
			
		||||
        this.preProcessor = b.preProcessor;
 | 
			
		||||
        this.sentenceProvider = b.sentenceProvider;
 | 
			
		||||
        this.sentencePairProvider = b.sentencePairProvider;
 | 
			
		||||
        this.lengthHandling = b.lengthHandling;
 | 
			
		||||
        this.featureArrays = b.featureArrays;
 | 
			
		||||
        this.vocabMap = b.vocabMap;
 | 
			
		||||
@ -161,11 +175,14 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
        this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
 | 
			
		||||
        this.maskToken = b.maskToken;
 | 
			
		||||
        this.prependToken = b.prependToken;
 | 
			
		||||
        this.appendToken = b.appendToken;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public boolean hasNext() {
 | 
			
		||||
        if (sentenceProvider != null)
 | 
			
		||||
            return sentenceProvider.hasNext();
 | 
			
		||||
        return sentencePairProvider.hasNext();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
@ -181,29 +198,38 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
    @Override
 | 
			
		||||
    public MultiDataSet next(int num) {
 | 
			
		||||
        Preconditions.checkState(hasNext(), "No next element available");
 | 
			
		||||
 | 
			
		||||
        List<Pair<String, String>> list = new ArrayList<>(num);
 | 
			
		||||
        List<Pair<List<String>, String>> tokensAndLabelList;
 | 
			
		||||
        int mbSize = 0;
 | 
			
		||||
        int outLength;
 | 
			
		||||
        long[] segIdOnesFrom = null;
 | 
			
		||||
        if (sentenceProvider != null) {
 | 
			
		||||
            List<Pair<String, String>> list = new ArrayList<>(num);
 | 
			
		||||
            while (sentenceProvider.hasNext() && mbSize++ < num) {
 | 
			
		||||
                list.add(sentenceProvider.nextSentence());
 | 
			
		||||
            }
 | 
			
		||||
            SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(list);
 | 
			
		||||
            tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
 | 
			
		||||
            outLength = sentenceListProcessed.getMaxL();
 | 
			
		||||
        } else if (sentencePairProvider != null) {
 | 
			
		||||
            List<Triple<String, String, String>> listPairs = new ArrayList<>(num);
 | 
			
		||||
            while (sentencePairProvider.hasNext() && mbSize++ < num) {
 | 
			
		||||
                listPairs.add(sentencePairProvider.nextSentencePair());
 | 
			
		||||
            }
 | 
			
		||||
            SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(listPairs);
 | 
			
		||||
            tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
 | 
			
		||||
            outLength = sentencePairListProcessed.getMaxL();
 | 
			
		||||
            segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
 | 
			
		||||
        } else {
 | 
			
		||||
            //TODO - other types of iterators...
 | 
			
		||||
            throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list);
 | 
			
		||||
        List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
 | 
			
		||||
        int outLength = outLTokenizedSentencesPair.getLeft();
 | 
			
		||||
 | 
			
		||||
        Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength);
 | 
			
		||||
        Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
 | 
			
		||||
        INDArray[] featureArray = featuresAndMaskArraysPair.getFirst();
 | 
			
		||||
        INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond();
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength);
 | 
			
		||||
        Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokensAndLabelList, featureArray, outLength);
 | 
			
		||||
        INDArray[] labelArray = labelsAndMaskArraysPair.getFirst();
 | 
			
		||||
        INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond();
 | 
			
		||||
 | 
			
		||||
@ -224,32 +250,59 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
    public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) {
 | 
			
		||||
 | 
			
		||||
        List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
 | 
			
		||||
        SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(sentencesWithNullLabel);
 | 
			
		||||
        List<Pair<List<String>, String>> tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
 | 
			
		||||
        int outLength = sentenceListProcessed.getMaxL();
 | 
			
		||||
 | 
			
		||||
        Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel);
 | 
			
		||||
        List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
 | 
			
		||||
        int outLength = outLTokenizedSentencesPair.getLeft();
 | 
			
		||||
 | 
			
		||||
        Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength);
 | 
			
		||||
        if (preProcessor != null) {
 | 
			
		||||
            Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
 | 
			
		||||
            MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null);
 | 
			
		||||
            preProcessor.preProcess(dummyMDS);
 | 
			
		||||
            return new Pair<INDArray[],INDArray[]>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
 | 
			
		||||
            return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
 | 
			
		||||
        }
 | 
			
		||||
        return convertMiniBatchFeatures(tokenizedSentences, outLength);
 | 
			
		||||
        return convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokenizedSentences, int outLength) {
 | 
			
		||||
        int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
 | 
			
		||||
    /**
 | 
			
		||||
     * For use during inference. Will convert a given pair of a list of sentences to features and feature masks as appropriate.
 | 
			
		||||
     *
 | 
			
		||||
     * @param listOnlySentencePairs
 | 
			
		||||
     * @return Pair of INDArrays[], first element is feature arrays and the second is the masks array
 | 
			
		||||
     */
 | 
			
		||||
    public Pair<INDArray[], INDArray[]> featurizeSentencePairs(List<Pair<String, String>> listOnlySentencePairs) {
 | 
			
		||||
        Preconditions.checkState(sentencePairProvider != null, "The featurizeSentencePairs method is meant for inference with sentence pairs. Use only when the sentence pair provider is set (i.e not null).");
 | 
			
		||||
 | 
			
		||||
        List<Triple<String, String, String>> sentencePairsWithNullLabel = addDummyLabelForPairs(listOnlySentencePairs);
 | 
			
		||||
        SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(sentencePairsWithNullLabel);
 | 
			
		||||
        List<Pair<List<String>, String>> tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
 | 
			
		||||
        int outLength = sentencePairListProcessed.getMaxL();
 | 
			
		||||
        long[] segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
 | 
			
		||||
        if (preProcessor != null) {
 | 
			
		||||
            Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
 | 
			
		||||
            MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featuresAndMaskArraysPair.getFirst(), null, featuresAndMaskArraysPair.getSecond(), null);
 | 
			
		||||
            preProcessor.preProcess(dummyMDS);
 | 
			
		||||
            return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
 | 
			
		||||
        }
 | 
			
		||||
        return convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokensAndLabelList, int outLength, long[] segIdOnesFrom) {
 | 
			
		||||
        int mbPadded = padMinibatches ? minibatchSize : tokensAndLabelList.size();
 | 
			
		||||
        int[][] outIdxs = new int[mbPadded][outLength];
 | 
			
		||||
        int[][] outMask = new int[mbPadded][outLength];
 | 
			
		||||
        for (int i = 0; i < tokenizedSentences.size(); i++) {
 | 
			
		||||
            Pair<List<String>, String> p = tokenizedSentences.get(i);
 | 
			
		||||
        int[][] outSegmentId = null;
 | 
			
		||||
        if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID)
 | 
			
		||||
            outSegmentId = new int[mbPadded][outLength];
 | 
			
		||||
        for (int i = 0; i < tokensAndLabelList.size(); i++) {
 | 
			
		||||
            Pair<List<String>, String> p = tokensAndLabelList.get(i);
 | 
			
		||||
            List<String> t = p.getFirst();
 | 
			
		||||
            for (int j = 0; j < outLength && j < t.size(); j++) {
 | 
			
		||||
                Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j));
 | 
			
		||||
                int idx = vocabMap.get(t.get(j));
 | 
			
		||||
                outIdxs[i][j] = idx;
 | 
			
		||||
                outMask[i][j] = 1;
 | 
			
		||||
                if (segIdOnesFrom != null && j >= segIdOnesFrom[i])
 | 
			
		||||
                    outSegmentId[i][j] = 1;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@ -260,8 +313,7 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
        INDArray[] f;
 | 
			
		||||
        INDArray[] fm;
 | 
			
		||||
        if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) {
 | 
			
		||||
            //For now: always segment index 0 (only single s sequence input supported)
 | 
			
		||||
            outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength);
 | 
			
		||||
            outSegmentIdArr = Nd4j.createFromArray(outSegmentId);
 | 
			
		||||
            f = new INDArray[]{outIdxsArr, outSegmentIdArr};
 | 
			
		||||
            fm = new INDArray[]{outMaskArr, null};
 | 
			
		||||
        } else {
 | 
			
		||||
@ -271,16 +323,15 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
        return new Pair<>(f, fm);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Pair<Integer, List<Pair<List<String>, String>>> tokenizeMiniBatch(List<Pair<String, String>> list) {
 | 
			
		||||
    private SentenceListProcessed tokenizeMiniBatch(List<Pair<String, String>> list) {
 | 
			
		||||
        //Get and tokenize the sentences for this minibatch
 | 
			
		||||
        List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(list.size());
 | 
			
		||||
        SentenceListProcessed sentenceListProcessed = new SentenceListProcessed(list.size());
 | 
			
		||||
        int longestSeq = -1;
 | 
			
		||||
        for (Pair<String, String> p : list) {
 | 
			
		||||
            List<String> tokens = tokenizeSentence(p.getFirst());
 | 
			
		||||
            tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
 | 
			
		||||
            sentenceListProcessed.addProcessedToList(new Pair<>(tokens, p.getSecond()));
 | 
			
		||||
            longestSeq = Math.max(longestSeq, tokens.size());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        //Determine output array length...
 | 
			
		||||
        int outLength;
 | 
			
		||||
        switch (lengthHandling) {
 | 
			
		||||
@ -296,7 +347,52 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
            default:
 | 
			
		||||
                throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
 | 
			
		||||
        }
 | 
			
		||||
        return new Pair<>(outLength, tokenizedSentences);
 | 
			
		||||
        sentenceListProcessed.setMaxL(outLength);
 | 
			
		||||
        return sentenceListProcessed;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private SentencePairListProcessed tokenizePairsMiniBatch(List<Triple<String, String, String>> listPairs) {
 | 
			
		||||
        SentencePairListProcessed sentencePairListProcessed = new SentencePairListProcessed(listPairs.size());
 | 
			
		||||
        for (Triple<String, String, String> t : listPairs) {
 | 
			
		||||
            List<String> tokensL = tokenizeSentence(t.getFirst(), true);
 | 
			
		||||
            List<String> tokensR = tokenizeSentence(t.getSecond(), true);
 | 
			
		||||
            List<String> tokens = new ArrayList<>(maxTokens);
 | 
			
		||||
            int maxLength = maxTokens;
 | 
			
		||||
            if (prependToken != null)
 | 
			
		||||
                maxLength--;
 | 
			
		||||
            if (appendToken != null)
 | 
			
		||||
                maxLength -= 2;
 | 
			
		||||
            if (tokensL.size() + tokensR.size() > maxLength) {
 | 
			
		||||
                boolean shortOnL = tokensL.size() < tokensR.size();
 | 
			
		||||
                int shortSize = Math.min(tokensL.size(), tokensR.size());
 | 
			
		||||
                if (shortSize > maxLength / 2) {
 | 
			
		||||
                    //both lists need to be sliced
 | 
			
		||||
                    tokensL.subList(maxLength / 2, tokensL.size()).clear(); //if maxsize/2 is odd pop extra on L side to match implementation in TF
 | 
			
		||||
                    tokensR.subList(maxLength - maxLength / 2, tokensR.size()).clear();
 | 
			
		||||
                } else {
 | 
			
		||||
                    //slice longer list
 | 
			
		||||
                    if (shortOnL) {
 | 
			
		||||
                        //longer on R - slice R
 | 
			
		||||
                        tokensR.subList(maxLength - tokensL.size(), tokensR.size()).clear();
 | 
			
		||||
                    } else {
 | 
			
		||||
                        //longer on L - slice L
 | 
			
		||||
                        tokensL.subList(maxLength - tokensR.size(), tokensL.size()).clear();
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            if (prependToken != null)
 | 
			
		||||
                tokens.add(prependToken);
 | 
			
		||||
            tokens.addAll(tokensL);
 | 
			
		||||
            if (appendToken != null)
 | 
			
		||||
                tokens.add(appendToken);
 | 
			
		||||
            int segIdOnesFrom = tokens.size();
 | 
			
		||||
            tokens.addAll(tokensR);
 | 
			
		||||
            if (appendToken != null)
 | 
			
		||||
                tokens.add(appendToken);
 | 
			
		||||
            sentencePairListProcessed.addProcessedToList(segIdOnesFrom, new Pair<>(tokens, t.getThird()));
 | 
			
		||||
        }
 | 
			
		||||
        sentencePairListProcessed.setMaxL(maxTokens);
 | 
			
		||||
        return sentencePairListProcessed;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
 | 
			
		||||
@ -316,6 +412,14 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
                    classLabels[i] = labels.indexOf(lbl);
 | 
			
		||||
                    Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
 | 
			
		||||
                }
 | 
			
		||||
            } else if (sentencePairProvider != null) {
 | 
			
		||||
                numClasses = sentencePairProvider.numLabelClasses();
 | 
			
		||||
                List<String> labels = sentencePairProvider.allLabels();
 | 
			
		||||
                for (int i = 0; i < mbSize; i++) {
 | 
			
		||||
                    String lbl = tokenizedSentences.get(i).getRight();
 | 
			
		||||
                    classLabels[i] = labels.indexOf(lbl);
 | 
			
		||||
                    Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
 | 
			
		||||
                }
 | 
			
		||||
            } else {
 | 
			
		||||
                throw new RuntimeException();
 | 
			
		||||
            }
 | 
			
		||||
@ -392,16 +496,22 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private List<String> tokenizeSentence(String sentence) {
 | 
			
		||||
        return tokenizeSentence(sentence, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private List<String> tokenizeSentence(String sentence, boolean ignorePrependAppend) {
 | 
			
		||||
        Tokenizer t = tokenizerFactory.create(sentence);
 | 
			
		||||
 | 
			
		||||
        List<String> tokens = new ArrayList<>();
 | 
			
		||||
        if (prependToken != null)
 | 
			
		||||
        if (prependToken != null && !ignorePrependAppend)
 | 
			
		||||
            tokens.add(prependToken);
 | 
			
		||||
 | 
			
		||||
        while (t.hasMoreTokens()) {
 | 
			
		||||
            String token = t.nextToken();
 | 
			
		||||
            tokens.add(token);
 | 
			
		||||
        }
 | 
			
		||||
        if (appendToken != null && !ignorePrependAppend)
 | 
			
		||||
            tokens.add(appendToken);
 | 
			
		||||
        return tokens;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -414,6 +524,13 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
        return list;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private List<Triple<String, String, String>> addDummyLabelForPairs(List<Pair<String, String>> listOnlySentencePairs) {
 | 
			
		||||
        List<Triple<String, String, String>> list = new ArrayList<>(listOnlySentencePairs.size());
 | 
			
		||||
        for (Pair<String, String> p : listOnlySentencePairs) {
 | 
			
		||||
            list.add(new Triple<String, String, String>(p.getFirst(), p.getSecond(), null));
 | 
			
		||||
        }
 | 
			
		||||
        return list;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public boolean resetSupported() {
 | 
			
		||||
@ -446,12 +563,14 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
        protected boolean padMinibatches = false;
 | 
			
		||||
        protected MultiDataSetPreProcessor preProcessor;
 | 
			
		||||
        protected LabeledSentenceProvider sentenceProvider = null;
 | 
			
		||||
        protected LabeledPairSentenceProvider sentencePairProvider = null;
 | 
			
		||||
        protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
 | 
			
		||||
        protected Map<String, Integer> vocabMap;   //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
 | 
			
		||||
        protected BertSequenceMasker masker = new BertMaskedLMMasker();
 | 
			
		||||
        protected UnsupervisedLabelFormat unsupervisedLabelFormat;
 | 
			
		||||
        protected String maskToken;
 | 
			
		||||
        protected String prependToken;
 | 
			
		||||
        protected String appendToken;
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
 | 
			
		||||
@ -519,14 +638,21 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised
 | 
			
		||||
         * use case, the labels will be ignored.
 | 
			
		||||
         * Specify the source of the data for classification.
 | 
			
		||||
         */
 | 
			
		||||
        public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) {
 | 
			
		||||
            this.sentenceProvider = sentenceProvider;
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Specify the source of the data for classification on sentence pairs.
 | 
			
		||||
         */
 | 
			
		||||
        public Builder sentencePairProvider(LabeledPairSentenceProvider sentencePairProvider) {
 | 
			
		||||
            this.sentencePairProvider = sentencePairProvider;
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Specify what arrays should be returned. See {@link BertIterator} for more details.
 | 
			
		||||
         */
 | 
			
		||||
@ -591,6 +717,19 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Append the specified token to the sequences, when doing training on sentence pairs.<br>
 | 
			
		||||
         * Generally "[SEP]" is used
 | 
			
		||||
         * No token in appended by default.
 | 
			
		||||
         *
 | 
			
		||||
         * @param appendToken Token at end of each sentence for pairs of sentences (null: no token will be appended)
 | 
			
		||||
         * @return
 | 
			
		||||
         */
 | 
			
		||||
        public Builder appendToken(String appendToken) {
 | 
			
		||||
            this.appendToken = appendToken;
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        public BertIterator build() {
 | 
			
		||||
            Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
 | 
			
		||||
            Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
 | 
			
		||||
@ -598,9 +737,69 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
            Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method");
 | 
			
		||||
            Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via masker(BertSequenceMasker) method");
 | 
			
		||||
            Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\" must be specified");
 | 
			
		||||
 | 
			
		||||
            if (sentencePairProvider != null) {
 | 
			
		||||
                Preconditions.checkState(task == Task.SEQ_CLASSIFICATION, "Currently only supervised sequence classification is set up with sentence pairs. \".task(BertIterator.Task.SEQ_CLASSIFICATION)\" is required with a sentence pair provider");
 | 
			
		||||
                Preconditions.checkState(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID, "Currently only supervised sequence classification is set up with sentence pairs. \".featureArrays(FeatureArrays.INDICES_MASK_SEGMENTID)\" is required with a sentence pair provider");
 | 
			
		||||
                Preconditions.checkState(lengthHandling == LengthHandling.FIXED_LENGTH, "Currently only fixed length is supported for sentence pairs. \".lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxLength)\" is required with a sentence pair provider");
 | 
			
		||||
                Preconditions.checkState(sentencePairProvider != null, "Provide either a sentence provider or a sentence pair provider. Both cannot be non null");
 | 
			
		||||
            }
 | 
			
		||||
            if (appendToken != null) {
 | 
			
		||||
                Preconditions.checkState(sentencePairProvider != null, "Tokens are only appended with sentence pairs. Sentence pair provider is not set. Set sentence pair provider.");
 | 
			
		||||
            }
 | 
			
		||||
            return new BertIterator(this);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static class SentencePairListProcessed {
 | 
			
		||||
        private int listLength = 0;
 | 
			
		||||
 | 
			
		||||
        @Getter
 | 
			
		||||
        private long[] segIdOnesFrom;
 | 
			
		||||
        private int cursor = 0;
 | 
			
		||||
        private SentenceListProcessed sentenceListProcessed;
 | 
			
		||||
 | 
			
		||||
        private SentencePairListProcessed(int listLength) {
 | 
			
		||||
            this.listLength = listLength;
 | 
			
		||||
            segIdOnesFrom = new long[listLength];
 | 
			
		||||
            sentenceListProcessed = new SentenceListProcessed(listLength);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private void addProcessedToList(long segIdIdx, Pair<List<String>, String> tokenizedSentencePairAndLabel) {
 | 
			
		||||
            segIdOnesFrom[cursor] = segIdIdx;
 | 
			
		||||
            sentenceListProcessed.addProcessedToList(tokenizedSentencePairAndLabel);
 | 
			
		||||
            cursor++;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private void setMaxL(int maxL) {
 | 
			
		||||
            sentenceListProcessed.setMaxL(maxL);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private int getMaxL() {
 | 
			
		||||
            return sentenceListProcessed.getMaxL();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private List<Pair<List<String>, String>> getTokensAndLabelList() {
 | 
			
		||||
            return sentenceListProcessed.getTokensAndLabelList();
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static class SentenceListProcessed {
 | 
			
		||||
        private int listLength;
 | 
			
		||||
 | 
			
		||||
        @Getter
 | 
			
		||||
        @Setter
 | 
			
		||||
        private int maxL;
 | 
			
		||||
 | 
			
		||||
        @Getter
 | 
			
		||||
        private List<Pair<List<String>, String>> tokensAndLabelList;
 | 
			
		||||
 | 
			
		||||
        private SentenceListProcessed(int listLength) {
 | 
			
		||||
            this.listLength = listLength;
 | 
			
		||||
            tokensAndLabelList = new ArrayList<>(listLength);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private void addProcessedToList(Pair<List<String>, String> tokenizedSentenceAndLabel) {
 | 
			
		||||
            tokensAndLabelList.add(tokenizedSentenceAndLabel);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,60 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2019 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.deeplearning4j.iterator;
 | 
			
		||||
 | 
			
		||||
import org.nd4j.linalg.primitives.Triple;
 | 
			
		||||
 | 
			
		||||
import java.util.List;
 | 
			
		||||
 | 
			
		||||
/**
 * LabeledPairSentenceProvider: a simple iterator interface over a pair of sentences/documents that have a label.<br>
 * Used to feed sentence-pair tasks (for example via BertIterator's sentence pair support).
 */
public interface LabeledPairSentenceProvider {

    /**
     * Are there more sentences/documents available?
     */
    boolean hasNext();

    /**
     * @return Triple: two sentence/document texts and label
     */
    Triple<String, String, String> nextSentencePair();

    /**
     * Reset the iterator - including shuffling the order, if necessary/appropriate
     */
    void reset();

    /**
     * Return the total number of sentences, or -1 if not available
     */
    int totalNumSentences();

    /**
     * Return the list of labels - this also defines the class/integer label assignment order
     */
    List<String> allLabels();

    /**
     * Equivalent to allLabels().size()
     */
    int numLabelClasses();

}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,135 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2019 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.deeplearning4j.iterator.provider;
 | 
			
		||||
 | 
			
		||||
import lombok.NonNull;
 | 
			
		||||
import org.deeplearning4j.iterator.LabeledPairSentenceProvider;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.primitives.Triple;
 | 
			
		||||
import org.nd4j.linalg.util.MathUtils;
 | 
			
		||||
 | 
			
		||||
import java.util.*;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Iterate over a pair of sentences/documents,
 | 
			
		||||
 * where the sentences and labels are provided in lists.
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
public class CollectionLabeledPairSentenceProvider implements LabeledPairSentenceProvider {
 | 
			
		||||
 | 
			
		||||
    private final List<String> sentenceL;
 | 
			
		||||
    private final List<String> sentenceR;
 | 
			
		||||
    private final List<String> labels;
 | 
			
		||||
    private final Random rng;
 | 
			
		||||
    private final int[] order;
 | 
			
		||||
    private final List<String> allLabels;
 | 
			
		||||
 | 
			
		||||
    private int cursor = 0;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Lists containing sentences to iterate over with a third for labels
 | 
			
		||||
     * Sentences in the same position in the first two lists are considered a pair
 | 
			
		||||
     * @param sentenceL
 | 
			
		||||
     * @param sentenceR
 | 
			
		||||
     * @param labelsForSentences
 | 
			
		||||
     */
 | 
			
		||||
    public CollectionLabeledPairSentenceProvider(@NonNull List<String> sentenceL, @NonNull List<String> sentenceR,
 | 
			
		||||
                                                 @NonNull List<String> labelsForSentences) {
 | 
			
		||||
        this(sentenceL, sentenceR, labelsForSentences, new Random());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Lists containing sentences to iterate over with a third for labels
 | 
			
		||||
     * Sentences in the same position in the first two lists are considered a pair
 | 
			
		||||
     * @param sentenceL
 | 
			
		||||
     * @param sentenceR
 | 
			
		||||
     * @param labelsForSentences
 | 
			
		||||
     * @param rng If null, list order is not shuffled
 | 
			
		||||
     */
 | 
			
		||||
    public CollectionLabeledPairSentenceProvider(@NonNull List<String> sentenceL, List<String> sentenceR, @NonNull List<String> labelsForSentences,
 | 
			
		||||
                                                 Random rng) {
 | 
			
		||||
        if (sentenceR.size() != sentenceL.size()) {
 | 
			
		||||
            throw new IllegalArgumentException("Sentence lists must be same size (first list size: "
 | 
			
		||||
                    + sentenceL.size() + ", second list size: " + sentenceR.size() + ")");
 | 
			
		||||
        }
 | 
			
		||||
        if (sentenceR.size() != labelsForSentences.size()) {
 | 
			
		||||
            throw new IllegalArgumentException("Sentence pairs and labels must be same size (sentence pair size: "
 | 
			
		||||
                    + sentenceR.size() + ", labels size: " + labelsForSentences.size() + ")");
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        this.sentenceL = sentenceL;
 | 
			
		||||
        this.sentenceR = sentenceR;
 | 
			
		||||
        this.labels = labelsForSentences;
 | 
			
		||||
        this.rng = rng;
 | 
			
		||||
        if (rng == null) {
 | 
			
		||||
            order = null;
 | 
			
		||||
        } else {
 | 
			
		||||
            order = new int[sentenceR.size()];
 | 
			
		||||
            for (int i = 0; i < sentenceR.size(); i++) {
 | 
			
		||||
                order[i] = i;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            MathUtils.shuffleArray(order, rng);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        //Collect set of unique labels for all sentences
 | 
			
		||||
        Set<String> uniqueLabels = new HashSet<>(labelsForSentences);
 | 
			
		||||
        allLabels = new ArrayList<>(uniqueLabels);
 | 
			
		||||
        Collections.sort(allLabels);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public boolean hasNext() {
 | 
			
		||||
        return cursor < sentenceR.size();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Triple<String, String, String> nextSentencePair() {
 | 
			
		||||
        Preconditions.checkState(hasNext(),"No next element available");
 | 
			
		||||
        int idx;
 | 
			
		||||
        if (rng == null) {
 | 
			
		||||
            idx = cursor++;
 | 
			
		||||
        } else {
 | 
			
		||||
            idx = order[cursor++];
 | 
			
		||||
        }
 | 
			
		||||
        return new Triple<>(sentenceL.get(idx), sentenceR.get(idx), labels.get(idx));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public void reset() {
 | 
			
		||||
        cursor = 0;
 | 
			
		||||
        if (rng != null) {
 | 
			
		||||
            MathUtils.shuffleArray(order, rng);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public int totalNumSentences() {
 | 
			
		||||
        return sentenceR.size();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public List<String> allLabels() {
 | 
			
		||||
        return allLabels;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public int numLabelClasses() {
 | 
			
		||||
        return allLabels.size();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -18,6 +18,7 @@ package org.deeplearning4j.iterator.provider;
 | 
			
		||||
 | 
			
		||||
import lombok.NonNull;
 | 
			
		||||
import org.deeplearning4j.iterator.LabeledSentenceProvider;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.primitives.Pair;
 | 
			
		||||
import org.nd4j.linalg.util.MathUtils;
 | 
			
		||||
 | 
			
		||||
@ -66,10 +67,7 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        //Collect set of unique labels for all sentences
 | 
			
		||||
        Set<String> uniqueLabels = new HashSet<>();
 | 
			
		||||
        for (String s : labelsForSentences) {
 | 
			
		||||
            uniqueLabels.add(s);
 | 
			
		||||
        }
 | 
			
		||||
        Set<String> uniqueLabels = new HashSet<>(labelsForSentences);
 | 
			
		||||
        allLabels = new ArrayList<>(uniqueLabels);
 | 
			
		||||
        Collections.sort(allLabels);
 | 
			
		||||
    }
 | 
			
		||||
@ -81,6 +79,7 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Pair<String, String> nextSentence() {
 | 
			
		||||
        Preconditions.checkState(hasNext(), "No next element available");
 | 
			
		||||
        int idx;
 | 
			
		||||
        if (rng == null) {
 | 
			
		||||
            idx = cursor++;
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,6 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
			
		||||
 * Copyright (c) 2019 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
@ -27,6 +28,7 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
import org.nd4j.linalg.indexing.NDArrayIndex;
 | 
			
		||||
import org.nd4j.linalg.primitives.Pair;
 | 
			
		||||
import org.nd4j.linalg.primitives.Triple;
 | 
			
		||||
import org.nd4j.resources.Resources;
 | 
			
		||||
 | 
			
		||||
import java.io.File;
 | 
			
		||||
@ -43,7 +45,8 @@ public class TestBertIterator extends BaseDL4JTest {
 | 
			
		||||
    private File pathToVocab = Resources.asFile("other/vocab.txt");
 | 
			
		||||
    private static Charset c = StandardCharsets.UTF_8;
 | 
			
		||||
 | 
			
		||||
    public TestBertIterator() throws IOException{ }
 | 
			
		||||
    public TestBertIterator() throws IOException {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(timeout = 20000L)
 | 
			
		||||
    public void testBertSequenceClassification() throws Exception {
 | 
			
		||||
@ -102,10 +105,10 @@ public class TestBertIterator extends BaseDL4JTest {
 | 
			
		||||
        assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
 | 
			
		||||
        assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
 | 
			
		||||
 | 
			
		||||
        b.next(); //pop the third element
 | 
			
		||||
        assertFalse(b.hasNext());
 | 
			
		||||
        b.reset();
 | 
			
		||||
        assertTrue(b.hasNext());
 | 
			
		||||
        MultiDataSet mds2 = b.next();
 | 
			
		||||
 | 
			
		||||
        forInference.set(0, toTokenize2);
 | 
			
		||||
        //Same thing, but with segment ID also
 | 
			
		||||
@ -120,6 +123,7 @@ public class TestBertIterator extends BaseDL4JTest {
 | 
			
		||||
                .build();
 | 
			
		||||
        mds = b.next();
 | 
			
		||||
        assertEquals(2, mds.getFeatures().length);
 | 
			
		||||
        //assertEquals(2, mds.getFeaturesMaskArrays().length); second element is null...
 | 
			
		||||
        assertEquals(2, b.featurizeSentences(forInference).getFirst().length);
 | 
			
		||||
        //Segment ID should be all 0s for single segment task
 | 
			
		||||
        INDArray segmentId = expM.like();
 | 
			
		||||
@ -152,9 +156,10 @@ public class TestBertIterator extends BaseDL4JTest {
 | 
			
		||||
        System.out.println(mds.getLabels(0));
 | 
			
		||||
        System.out.println(mds.getLabelsMaskArray(0));
 | 
			
		||||
 | 
			
		||||
        b.next(); //pop the third element
 | 
			
		||||
        assertFalse(b.hasNext());
 | 
			
		||||
        b.reset();
 | 
			
		||||
        mds = b.next();
 | 
			
		||||
        assertTrue(b.hasNext());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(timeout = 20000L)
 | 
			
		||||
@ -168,6 +173,7 @@ public class TestBertIterator extends BaseDL4JTest {
 | 
			
		||||
        INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
 | 
			
		||||
        INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
 | 
			
		||||
        List<String> tokens = t.create(toTokenize1).getTokens();
 | 
			
		||||
        System.out.println(tokens);
 | 
			
		||||
        Map<String, Integer> m = t.getVocab();
 | 
			
		||||
        for (int i = 0; i < tokens.size(); i++) {
 | 
			
		||||
            int idx = m.get(tokens.get(i));
 | 
			
		||||
@ -178,6 +184,7 @@ public class TestBertIterator extends BaseDL4JTest {
 | 
			
		||||
        INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
 | 
			
		||||
        INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
 | 
			
		||||
        List<String> tokens2 = t.create(toTokenize2).getTokens();
 | 
			
		||||
        System.out.println(tokens2);
 | 
			
		||||
        for (int i = 0; i < tokens2.size(); i++) {
 | 
			
		||||
            String token = tokens2.get(i);
 | 
			
		||||
            if (!m.containsKey(token)) {
 | 
			
		||||
@ -236,9 +243,11 @@ public class TestBertIterator extends BaseDL4JTest {
 | 
			
		||||
        Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
 | 
			
		||||
        String toTokenize1 = "I saw a girl with a telescope.";
 | 
			
		||||
        String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
 | 
			
		||||
        String toTokenize3 = "Goodnight noises everywhere";
 | 
			
		||||
        List<String> forInference = new ArrayList<>();
 | 
			
		||||
        forInference.add(toTokenize1);
 | 
			
		||||
        forInference.add(toTokenize2);
 | 
			
		||||
        forInference.add(toTokenize3);
 | 
			
		||||
        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
 | 
			
		||||
        INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
 | 
			
		||||
        INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
 | 
			
		||||
@ -263,14 +272,27 @@ public class TestBertIterator extends BaseDL4JTest {
 | 
			
		||||
            expM1.putScalar(0, i, 1);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        INDArray zeros = Nd4j.create(DataType.INT, 2, 16);
 | 
			
		||||
        INDArray expEx3 = Nd4j.create(DataType.INT, 1, 16);
 | 
			
		||||
        INDArray expM3 = Nd4j.create(DataType.INT, 1, 16);
 | 
			
		||||
        List<String> tokens3 = t.create(toTokenize3).getTokens();
 | 
			
		||||
        for (int i = 0; i < tokens3.size(); i++) {
 | 
			
		||||
            String token = tokens3.get(i);
 | 
			
		||||
            if (!m.containsKey(token)) {
 | 
			
		||||
                throw new IllegalStateException("Unknown token: \"" + token + "\"");
 | 
			
		||||
            }
 | 
			
		||||
            int idx = m.get(token);
 | 
			
		||||
            expEx3.putScalar(0, i, idx);
 | 
			
		||||
            expM3.putScalar(0, i, 1);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        INDArray expF = Nd4j.vstack(expEx0, expEx1, zeros);
 | 
			
		||||
        INDArray expM = Nd4j.vstack(expM0, expM1, zeros);
 | 
			
		||||
        INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {0, 0}, {0, 0}});
 | 
			
		||||
        INDArray zeros = Nd4j.create(DataType.INT, 1, 16);
 | 
			
		||||
        INDArray expF = Nd4j.vstack(expEx0, expEx1, expEx3, zeros);
 | 
			
		||||
        INDArray expM = Nd4j.vstack(expM0, expM1, expM3, zeros);
 | 
			
		||||
        INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {1, 0}, {0, 0}});
 | 
			
		||||
        INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1);
 | 
			
		||||
        expLM.putScalar(0, 0, 1);
 | 
			
		||||
        expLM.putScalar(1, 0, 1);
 | 
			
		||||
        expLM.putScalar(2, 0, 1);
 | 
			
		||||
 | 
			
		||||
        //--------------------------------------------------------------
 | 
			
		||||
 | 
			
		||||
@ -305,9 +327,234 @@ public class TestBertIterator extends BaseDL4JTest {
 | 
			
		||||
        assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
    public void testSentencePairsSingle() throws IOException {
        //Verifies that a sentence-pair dataset equals the horizontal concatenation (hstack) of the
        //corresponding single-sentence datasets, across three truncation/padding scenarios
        String shortSent = "I saw a girl with a telescope.";
        String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
        boolean prependAppend;
        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
        int shortL = t.create(shortSent).countTokens();
        int longL = t.create(longSent).countTokens();

        Triple<MultiDataSet, MultiDataSet, MultiDataSet> multiDataSetTriple;
        MultiDataSet shortLongPair, shortSentence, longSentence;

        // check for pair max length exactly equal to sum of lengths - pop neither no padding
        // should be the same as hstack with segment ids 1 for second sentence
        prependAppend = true;
        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend);
        shortLongPair = multiDataSetTriple.getFirst();
        shortSentence = multiDataSetTriple.getSecond();
        longSentence = multiDataSetTriple.getThird();
        assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
        longSentence.getFeatures(1).addi(1);    //single-sentence segment ids are all 0; as second of a pair they should all be 1
        assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
        assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));

        //check for pair max length greater than sum of lengths - pop neither with padding
        // features should be the same as hstack of shorter and longer padded with prepend/append
        // segment id should 1 only in the longer for part of the length of the sentence
        prependAppend = true;
        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend);
        shortLongPair = multiDataSetTriple.getFirst();
        shortSentence = multiDataSetTriple.getSecond();
        longSentence = multiDataSetTriple.getThird();
        assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
        longSentence.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part
        assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
        assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));

        //check for pair max length less than shorter sentence - pop both
        //should be the same as hstack with segment ids 1 for second sentence if no prepend/append
        int maxL = shortL - 2;
        prependAppend = false;
        multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend);
        shortLongPair = multiDataSetTriple.getFirst();
        shortSentence = multiDataSetTriple.getSecond();
        longSentence = multiDataSetTriple.getThird();
        assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
        longSentence.getFeatures(1).addi(1);    //second sentence of the pair gets segment id 1
        assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
        assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
    }
 | 
			
		||||
 | 
			
		||||
    @Test
    public void testSentencePairsUnequalLengths() throws IOException {
        //check for pop only longer (i.e between longer and longer + shorter), first row pop from second sentence, next row pop from first sentence, nothing to pop in the third row
        //should be identical to hstack if there is no append, prepend
        //batch size is 2
        int mbS = 4;
        String shortSent = "I saw a girl with a telescope.";
        String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
        String sent1 = "Goodnight noises everywhere"; //shorter than shortSent - no popping
        String sent2 = "Goodnight moon"; //shorter than shortSent - no popping
        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
        int shortL = t.create(shortSent).countTokens();
        int longL = t.create(longSent).countTokens();
        int sent1L = t.create(sent1).countTokens();
        int sent2L = t.create(sent2).countTokens();
        //won't check 2*shortL + 1 because this will always pop on the left
        for (int maxL = longL + shortL - 1; maxL > 2 * shortL; maxL--) {
            //Left sentences of each pair, iterated singly with a fixed length long enough to avoid any truncation
            MultiDataSet leftMDS = BertIterator.builder()
                    .tokenizer(t)
                    .minibatchSize(mbS)
                    .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
                    .vocabMap(t.getVocab())
                    .task(BertIterator.Task.SEQ_CLASSIFICATION)
                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
                    .sentenceProvider(new TestSentenceProvider())
                    .padMinibatches(true)
                    .build().next();

            //Right sentences of each pair (the provider returns them inverted)
            MultiDataSet rightMDS = BertIterator.builder()
                    .tokenizer(t)
                    .minibatchSize(mbS)
                    .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
                    .vocabMap(t.getVocab())
                    .task(BertIterator.Task.SEQ_CLASSIFICATION)
                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
                    .sentenceProvider(new TestSentenceProvider(true))
                    .padMinibatches(true)
                    .build().next();

            //The actual sentence-pair dataset under test, truncated to maxL
            MultiDataSet pairMDS = BertIterator.builder()
                    .tokenizer(t)
                    .minibatchSize(mbS)
                    .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
                    .vocabMap(t.getVocab())
                    .task(BertIterator.Task.SEQ_CLASSIFICATION)
                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) //random big num guaranteed to be longer than either
                    .sentencePairProvider(new TestSentencePairProvider())
                    .padMinibatches(true)
                    .build().next();

            //Left sentences here are {{shortSent},
            //                         {longSent},
            //                         {Sent1}}
            //Right sentences here are {{longSent},
            //                         {shortSent},
            //                         {Sent2}}
            //The sentence pairs here are {{shortSent,longSent},
            //                             {longSent,shortSent}
            //                             {Sent1, Sent2}}

            //CHECK FEATURES
            //Build the expected pair features row-by-row by truncating + hstacking the single-sentence features
            INDArray combinedFeat = Nd4j.create(DataType.INT,mbS,maxL);
            //left side
            INDArray leftFeatures = leftMDS.getFeatures(0);
            INDArray topLSentFeat = leftFeatures.getRow(0).get(NDArrayIndex.interval(0, shortL));
            INDArray midLSentFeat = leftFeatures.getRow(1).get(NDArrayIndex.interval(0, maxL - shortL));
            INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0,sent1L));
            //right side
            INDArray rightFeatures = rightMDS.getFeatures(0);
            INDArray topRSentFeat = rightFeatures.getRow(0).get(NDArrayIndex.interval(0, maxL - shortL));
            INDArray midRSentFeat = rightFeatures.getRow(1).get(NDArrayIndex.interval(0, shortL));
            INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0,sent2L));
            //expected pair
            combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat,topRSentFeat));
            combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat,midRSentFeat));
            combinedFeat.getRow(2).get(NDArrayIndex.interval(0,sent1L+sent2L)).addi(Nd4j.hstack(bottomLSentFeat,bottomRSentFeat));

            assertEquals(maxL, pairMDS.getFeatures(0).shape()[1]);
            assertArrayEquals(combinedFeat.shape(), pairMDS.getFeatures(0).shape());
            assertEquals(combinedFeat, pairMDS.getFeatures(0));

            //CHECK SEGMENT ID
            //Expected segment ids: 0 for the first sentence of each pair, 1 for the second (0 in padding)
            INDArray combinedFetSeg = Nd4j.create(DataType.INT, mbS, maxL);
            combinedFetSeg.get(NDArrayIndex.point(0), NDArrayIndex.interval(shortL, maxL)).addi(1);
            combinedFetSeg.get(NDArrayIndex.point(1), NDArrayIndex.interval(maxL - shortL, maxL)).addi(1);
            combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L+sent2L)).addi(1);
            assertArrayEquals(combinedFetSeg.shape(), pairMDS.getFeatures(1).shape());
            assertEquals(maxL, combinedFetSeg.shape()[1]);
            assertEquals(combinedFetSeg, pairMDS.getFeatures(1));
        }
    }
 | 
			
		||||
 | 
			
		||||
    @Test
    public void testSentencePairFeaturizer() throws IOException {
        //Checks that featurizeSentencePairs(...) produces the same features/masks as iterating
        //the equivalent sentence pair provider
        String shortSent = "I saw a girl with a telescope.";
        String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
        List<Pair<String, String>> listSentencePair = new ArrayList<>();
        listSentencePair.add(new Pair<>(shortSent, longSent));
        listSentencePair.add(new Pair<>(longSent, shortSent));
        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
        BertIterator b = BertIterator.builder()
                .tokenizer(t)
                .minibatchSize(2)
                .padMinibatches(true)
                .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
                .vocabMap(t.getVocab())
                .task(BertIterator.Task.SEQ_CLASSIFICATION)
                .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)
                .sentencePairProvider(new TestSentencePairProvider())
                .prependToken("[CLS]")
                .appendToken("[SEP]")
                .build();
        MultiDataSet mds = b.next();
        INDArray[] featuresArr = mds.getFeatures();
        INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays();

        Pair<INDArray[], INDArray[]> p = b.featurizeSentencePairs(listSentencePair);
        assertEquals(p.getFirst().length, 2);
        assertEquals(featuresArr[0], p.getFirst()[0]);
        assertEquals(featuresArr[1], p.getFirst()[1]);
        //NOTE(review): the two commented-out assertions below presumably fail because the second
        //mask array is null for this configuration - confirm intended behavior before enabling them
        //assertEquals(p.getSecond().length, 2);
        assertEquals(featuresMaskArr[0], p.getSecond()[0]);
        //assertEquals(featuresMaskArr[1], p.getSecond()[1]);
    }
 | 
			
		||||
 | 
			
		||||
    /**
     * Returns three multidatasets from bert iterator based on given max lengths and whether to prepend/append
     * Idea is the sentence pair dataset can be constructed from the single sentence datasets
     * First one is constructed from a sentence pair "I saw a girl with a telescope." & "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"
     * Second one is constructed from the left of the sentence pair i.e "I saw a girl with a telescope."
     * Third one is constructed from the right of the sentence pair i.e "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"
     *
     * @param maxLengths    triple of (max length for the pair, max length for sentence one, max length for sentence two)
     * @param prependAppend if true, "[CLS]"/"[SEP]" tokens are prepended/appended, and the fixed
     *                      lengths are enlarged accordingly (+3 for the pair, +2/+1 for the singles)
     */
    private Triple<MultiDataSet, MultiDataSet, MultiDataSet> generateMultiDataSets(Triple<Integer, Integer, Integer> maxLengths, boolean prependAppend) throws IOException {
        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
        int maxforPair = maxLengths.getFirst();
        int maxPartOne = maxLengths.getSecond();
        int maxPartTwo = maxLengths.getThird();
        BertIterator.Builder commonBuilder;
        commonBuilder = BertIterator.builder()
                .tokenizer(t)
                .minibatchSize(1)
                .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
                .vocabMap(t.getVocab())
                .task(BertIterator.Task.SEQ_CLASSIFICATION);
        //NOTE(review): the same mutable builder instance is reused for all three iterators below,
        //so later calls overwrite earlier settings (and e.g. sentenceProvider is set after
        //sentencePairProvider) - confirm the builder tolerates this, or use fresh builders
        BertIterator shortLongPairFirstIter = commonBuilder
                .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxforPair + 3 : maxforPair)
                .sentencePairProvider(new TestSentencePairProvider())
                .prependToken(prependAppend ? "[CLS]" : null)
                .appendToken(prependAppend ? "[SEP]" : null)
                .build();
        BertIterator shortFirstIter = commonBuilder
                .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartOne + 2 : maxPartOne)
                .sentenceProvider(new TestSentenceProvider())
                .prependToken(prependAppend ? "[CLS]" : null)
                .appendToken(prependAppend ? "[SEP]" : null)
                .build();
        BertIterator longFirstIter = commonBuilder
                .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartTwo + 1 : maxPartTwo)
                .sentenceProvider(new TestSentenceProvider(true))
                .prependToken(null) //no prepend for the second sentence of a pair - only the append token
                .appendToken(prependAppend ? "[SEP]" : null)
                .build();
        return new Triple<>(shortLongPairFirstIter.next(), shortFirstIter.next(), longFirstIter.next());
    }
 | 
			
		||||
 | 
			
		||||
    private static class TestSentenceProvider implements LabeledSentenceProvider {
 | 
			
		||||
 | 
			
		||||
        private int pos = 0;
 | 
			
		||||
        private boolean invert;
 | 
			
		||||
 | 
			
		||||
        private TestSentenceProvider() {
 | 
			
		||||
            this.invert = false;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private TestSentenceProvider(boolean invert) {
 | 
			
		||||
            this.invert = invert;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public boolean hasNext() {
 | 
			
		||||
@ -317,10 +564,20 @@ public class TestBertIterator extends BaseDL4JTest {
 | 
			
		||||
        @Override
 | 
			
		||||
        public Pair<String, String> nextSentence() {
 | 
			
		||||
            Preconditions.checkState(hasNext());
 | 
			
		||||
            if(pos++ == 0){
 | 
			
		||||
                return new Pair<>("I saw a girl with a telescope.", "positive");
 | 
			
		||||
            } else {
 | 
			
		||||
            if (pos == 0) {
 | 
			
		||||
                pos++;
 | 
			
		||||
                if (!invert) return new Pair<>("I saw a girl with a telescope.", "positive");
 | 
			
		||||
                return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
 | 
			
		||||
            } else {
 | 
			
		||||
                if (pos == 1) {
 | 
			
		||||
                    pos++;
 | 
			
		||||
                    if (!invert) return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
 | 
			
		||||
                    return new Pair<>("I saw a girl with a telescope.", "positive");
 | 
			
		||||
                }
 | 
			
		||||
                pos++;
 | 
			
		||||
                if (!invert)
 | 
			
		||||
                    return new Pair<>("Goodnight noises everywhere", "positive");
 | 
			
		||||
                return new Pair<>("Goodnight moon", "positive");
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@ -331,8 +588,54 @@ public class TestBertIterator extends BaseDL4JTest {
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public int totalNumSentences() {
 | 
			
		||||
            return 3;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public List<String> allLabels() {
 | 
			
		||||
            return Arrays.asList("positive", "negative");
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public int numLabelClasses() {
 | 
			
		||||
            return 2;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static class TestSentencePairProvider implements LabeledPairSentenceProvider {
 | 
			
		||||
 | 
			
		||||
        private int pos = 0;
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public boolean hasNext() {
 | 
			
		||||
            return pos < totalNumSentences();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public Triple<String, String, String> nextSentencePair() {
 | 
			
		||||
            Preconditions.checkState(hasNext());
 | 
			
		||||
            if (pos == 0) {
 | 
			
		||||
                pos++;
 | 
			
		||||
                return new Triple<>("I saw a girl with a telescope.", "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "positive");
 | 
			
		||||
            } else {
 | 
			
		||||
                if (pos == 1) {
 | 
			
		||||
                    pos++;
 | 
			
		||||
                    return new Triple<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "I saw a girl with a telescope.", "negative");
 | 
			
		||||
                }
 | 
			
		||||
                pos++;
 | 
			
		||||
                return new Triple<>("Goodnight noises everywhere", "Goodnight moon", "positive");
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public void reset() {
 | 
			
		||||
            pos = 0;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public int totalNumSentences() {
 | 
			
		||||
            return 3;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public List<String> allLabels() {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user