diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java index c6a88ffb3..b5ca6c91a 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -79,6 +80,17 @@ import java.util.Map; * .build(); * } * + *
+ * Example of using an instantiated iterator for inference:
+ *
+ * {@code
+ *          BertIterator b;
+ *          List<String> forInference;
+ *          Pair<INDArray[], INDArray[]> featuresAndMask = b.featurizeSentences(forInference);
+ *          INDArray[] features = featuresAndMask.getFirst();
+ *          INDArray[] featureMasks = featuresAndMask.getSecond();
+ * }
+ * 
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.
*
* {@link LengthHandling} configuration:
@@ -89,7 +101,7 @@ import java.util.Map; * CLIP_ONLY: For any sequences longer than the specified maximum, clip them. If the maximum sequence length in * a minibatch is shorter than the specified maximum, no padding will occur. For sequences that are shorter than the * maximum (within the current minibatch) they will be zero padded and masked.
- *

+ *

* {@link FeatureArrays} configuration:
* Determines what arrays should be included.
* INDICES_MASK: Indices array and mask array only, no segment ID array. Returns 1 feature array, 1 feature mask array (plus labels).
@@ -107,8 +119,11 @@ import java.util.Map; public class BertIterator implements MultiDataSetIterator { public enum Task {UNSUPERVISED, SEQ_CLASSIFICATION} + public enum LengthHandling {FIXED_LENGTH, ANY_LENGTH, CLIP_ONLY} + public enum FeatureArrays {INDICES_MASK, INDICES_MASK_SEGMENTID} + public enum UnsupervisedLabelFormat {RANK2_IDX, RANK3_NCL, RANK3_LNC} protected Task task; @@ -116,12 +131,13 @@ public class BertIterator implements MultiDataSetIterator { protected int maxTokens = -1; protected int minibatchSize = 32; protected boolean padMinibatches = false; - @Getter @Setter + @Getter + @Setter protected MultiDataSetPreProcessor preProcessor; protected LabeledSentenceProvider sentenceProvider = null; protected LengthHandling lengthHandling; protected FeatureArrays featureArrays; - protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects? + protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects? protected BertSequenceMasker masker = null; protected UnsupervisedLabelFormat unsupervisedLabelFormat = null; protected String maskToken; @@ -130,7 +146,7 @@ public class BertIterator implements MultiDataSetIterator { protected List vocabKeysAsList; - protected BertIterator(Builder b){ + protected BertIterator(Builder b) { this.task = b.task; this.tokenizerFactory = b.tokenizerFactory; this.maxTokens = b.maxTokens; @@ -166,10 +182,10 @@ public class BertIterator implements MultiDataSetIterator { public MultiDataSet next(int num) { Preconditions.checkState(hasNext(), "No next element available"); - List> list = new ArrayList<>(num); - int count = 0; - if(sentenceProvider != null){ - while(sentenceProvider.hasNext() && count++ < num) { + List> list = new ArrayList<>(num); + int mbSize = 0; + if (sentenceProvider != null) { + while (sentenceProvider.hasNext() && mbSize++ < num) { list.add(sentenceProvider.nextSentence()); } } else { @@ -177,41 +193,60 @@ public class BertIterator implements 
MultiDataSetIterator { throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented"); } - //Get and tokenize the sentences for this minibatch - List, String>> tokenizedSentences = new ArrayList<>(num); - int longestSeq = -1; - for(Pair p : list){ - List tokens = tokenizeSentence(p.getFirst()); - tokenizedSentences.add(new Pair<>(tokens, p.getSecond())); - longestSeq = Math.max(longestSeq, tokens.size()); - } - //Determine output array length... - int outLength; - switch (lengthHandling){ - case FIXED_LENGTH: - outLength = maxTokens; - break; - case ANY_LENGTH: - outLength = longestSeq; - break; - case CLIP_ONLY: - outLength = Math.min(maxTokens, longestSeq); - break; - default: - throw new RuntimeException("Not implemented length handling mode: " + lengthHandling); - } + Pair, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list); + List, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight(); + int outLength = outLTokenizedSentencesPair.getLeft(); - int mb = tokenizedSentences.size(); - int mbPadded = padMinibatches ? minibatchSize : mb; + Pair featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength); + INDArray[] featureArray = featuresAndMaskArraysPair.getFirst(); + INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond(); + + + Pair labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength); + INDArray[] labelArray = labelsAndMaskArraysPair.getFirst(); + INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond(); + + org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(featureArray, labelArray, featureMaskArray, labelMaskArray); + if (preProcessor != null) + preProcessor.preProcess(mds); + + return mds; + } + + + /** + * For use during inference. Will convert a given list of sentences to features and feature masks as appropriate. 
+ * + * @param listOnlySentences + * @return Pair of INDArrays[], first element is feature arrays and the second is the masks array + */ + public Pair featurizeSentences(List listOnlySentences) { + + List> sentencesWithNullLabel = addDummyLabel(listOnlySentences); + + Pair, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel); + List, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight(); + int outLength = outLTokenizedSentencesPair.getLeft(); + + Pair featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength); + if (preProcessor != null) { + MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null); + preProcessor.preProcess(dummyMDS); + return new Pair(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays()); + } + return convertMiniBatchFeatures(tokenizedSentences, outLength); + } + + private Pair convertMiniBatchFeatures(List, String>> tokenizedSentences, int outLength) { + int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size(); int[][] outIdxs = new int[mbPadded][outLength]; int[][] outMask = new int[mbPadded][outLength]; - - for( int i=0; i,String> p = tokenizedSentences.get(i); + for (int i = 0; i < tokenizedSentences.size(); i++) { + Pair, String> p = tokenizedSentences.get(i); List t = p.getFirst(); - for( int j=0; j(f, fm); + } + private Pair, String>>> tokenizeMiniBatch(List> list) { + //Get and tokenize the sentences for this minibatch + List, String>> tokenizedSentences = new ArrayList<>(list.size()); + int longestSeq = -1; + for (Pair p : list) { + List tokens = tokenizeSentence(p.getFirst()); + tokenizedSentences.add(new Pair<>(tokens, p.getSecond())); + longestSeq = Math.max(longestSeq, tokens.size()); + } + + //Determine output array length... 
+ int outLength; + switch (lengthHandling) { + case FIXED_LENGTH: + outLength = maxTokens; + break; + case ANY_LENGTH: + outLength = longestSeq; + break; + case CLIP_ONLY: + outLength = Math.min(maxTokens, longestSeq); + break; + default: + throw new RuntimeException("Not implemented length handling mode: " + lengthHandling); + } + return new Pair<>(outLength, tokenizedSentences); + } + + private Pair convertMiniBatchLabels(List, String>> tokenizedSentences, INDArray[] featureArray, int outLength) { INDArray[] l = new INDArray[1]; INDArray[] lm; - if(task == Task.SEQ_CLASSIFICATION){ + int mbSize = tokenizedSentences.size(); + int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size(); + if (task == Task.SEQ_CLASSIFICATION) { //Sequence classification task: output is 2d, one-hot, shape [minibatch, numClasses] int numClasses; int[] classLabels = new int[mbPadded]; - if(sentenceProvider != null){ + if (sentenceProvider != null) { numClasses = sentenceProvider.numLabelClasses(); List labels = sentenceProvider.allLabels(); - for(int i=0; i= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl); @@ -252,21 +320,21 @@ public class BertIterator implements MultiDataSetIterator { throw new RuntimeException(); } l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses); - for( int i=0; i e : vocabMap.entrySet()){ + for (Map.Entry e : vocabMap.entrySet()) { arr[e.getValue()] = e.getKey(); } vocabKeysAsList = Arrays.asList(arr); @@ -276,31 +344,31 @@ public class BertIterator implements MultiDataSetIterator { int vocabSize = vocabMap.size(); INDArray labelArr; INDArray lMask = Nd4j.zeros(DataType.INT, mbPadded, outLength); - if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){ + if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) { labelArr = Nd4j.create(DataType.INT, mbPadded, outLength); - } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){ + } else if (unsupervisedLabelFormat 
== UnsupervisedLabelFormat.RANK3_NCL) { labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength); - } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){ + } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) { labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize); } else { throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat); } - for( int i=0; i tokens = tokenizedSentences.get(i).getFirst(); - Pair,boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList); + Pair, boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList); List maskedTokens = p.getFirst(); boolean[] predictionTarget = p.getSecond(); int seqLen = Math.min(predictionTarget.length, outLength); - for(int j=0; j(l, lm); } private List tokenizeSentence(String sentence) { Tokenizer t = tokenizerFactory.create(sentence); List tokens = new ArrayList<>(); - if(prependToken != null) + if (prependToken != null) tokens.add(prependToken); while (t.hasMoreTokens()) { @@ -341,6 +405,16 @@ public class BertIterator implements MultiDataSetIterator { return tokens; } + + private List> addDummyLabel(List listOnlySentences) { + List> list = new ArrayList<>(listOnlySentences.size()); + for (String s : listOnlySentences) { + list.add(new Pair(s, null)); + } + return list; + } + + @Override public boolean resetSupported() { return true; @@ -353,12 +427,12 @@ public class BertIterator implements MultiDataSetIterator { @Override public void reset() { - if(sentenceProvider != null){ + if (sentenceProvider != null) { sentenceProvider.reset(); } } - public static Builder builder(){ + public static Builder builder() { return new Builder(); } @@ -373,7 +447,7 @@ public class BertIterator implements MultiDataSetIterator { protected MultiDataSetPreProcessor preProcessor; protected LabeledSentenceProvider sentenceProvider = null; protected FeatureArrays featureArrays = 
FeatureArrays.INDICES_MASK_SEGMENTID; - protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects? + protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects? protected BertSequenceMasker masker = new BertMaskedLMMasker(); protected UnsupervisedLabelFormat unsupervisedLabelFormat; protected String maskToken; @@ -382,7 +456,7 @@ public class BertIterator implements MultiDataSetIterator { /** * Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details. */ - public Builder task(Task task){ + public Builder task(Task task) { this.task = task; return this; } @@ -392,18 +466,19 @@ public class BertIterator implements MultiDataSetIterator { * For BERT, typically {@link org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory} * is used */ - public Builder tokenizer(TokenizerFactory tokenizerFactory){ + public Builder tokenizer(TokenizerFactory tokenizerFactory) { this.tokenizerFactory = tokenizerFactory; return this; } /** * Specifies how the sequence length of the output data should be handled. See {@link BertIterator} for more details. 
- * @param lengthHandling Length handling - * @param maxLength Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH} + * + * @param lengthHandling Length handling + * @param maxLength Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH} * @return */ - public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength){ + public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength) { this.lengthHandling = lengthHandling; this.maxTokens = maxLength; return this; @@ -412,9 +487,10 @@ public class BertIterator implements MultiDataSetIterator { /** * Minibatch size to use (number of examples to train on for each iteration) * See also: {@link #padMinibatches} - * @param minibatchSize Minibatch size + * + * @param minibatchSize Minibatch size */ - public Builder minibatchSize(int minibatchSize){ + public Builder minibatchSize(int minibatchSize) { this.minibatchSize = minibatchSize; return this; } @@ -429,7 +505,7 @@ public class BertIterator implements MultiDataSetIterator { * Both options should result in exactly the same model. However, some BERT implementations may require exactly an * exact number of examples in all minibatches to function. */ - public Builder padMinibatches(boolean padMinibatches){ + public Builder padMinibatches(boolean padMinibatches) { this.padMinibatches = padMinibatches; return this; } @@ -437,7 +513,7 @@ public class BertIterator implements MultiDataSetIterator { /** * Set the preprocessor to be used on the MultiDataSets before returning them. Default: none (null) */ - public Builder preProcessor(MultiDataSetPreProcessor preProcessor){ + public Builder preProcessor(MultiDataSetPreProcessor preProcessor) { this.preProcessor = preProcessor; return this; } @@ -446,7 +522,7 @@ public class BertIterator implements MultiDataSetIterator { * Specify the source of the data for classification. 
Can also be used for unsupervised learning; in the unsupervised * use case, the labels will be ignored. */ - public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider){ + public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) { this.sentenceProvider = sentenceProvider; return this; } @@ -454,7 +530,7 @@ public class BertIterator implements MultiDataSetIterator { /** * Specify what arrays should be returned. See {@link BertIterator} for more details. */ - public Builder featureArrays(FeatureArrays featureArrays){ + public Builder featureArrays(FeatureArrays featureArrays) { this.featureArrays = featureArrays; return this; } @@ -465,7 +541,7 @@ public class BertIterator implements MultiDataSetIterator { * If using {@link BertWordPieceTokenizerFactory}, * this can be obtained using {@link BertWordPieceTokenizerFactory#getVocab()} */ - public Builder vocabMap(Map vocabMap){ + public Builder vocabMap(Map vocabMap) { this.vocabMap = vocabMap; return this; } @@ -475,7 +551,7 @@ public class BertIterator implements MultiDataSetIterator { * masked language model. This can be used to customize how the masking is performed.
* Default: {@link BertMaskedLMMasker} */ - public Builder masker(BertSequenceMasker masker){ + public Builder masker(BertSequenceMasker masker) { this.masker = masker; return this; } @@ -485,7 +561,7 @@ public class BertIterator implements MultiDataSetIterator { * masked language model. Used to specify the format that the labels should be returned in. * See {@link BertIterator} for more details. */ - public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat){ + public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat) { this.unsupervisedLabelFormat = labelFormat; return this; } @@ -497,7 +573,7 @@ public class BertIterator implements MultiDataSetIterator { * the exact behaviour will depend on what masker is used.
* Note that this must be in the vocabulary map set in {@link #vocabMap} */ - public Builder maskToken(String maskToken){ + public Builder maskToken(String maskToken) { this.maskToken = maskToken; return this; } @@ -510,12 +586,12 @@ public class BertIterator implements MultiDataSetIterator { * * @param prependToken The token to start each sequence with (null: no token will be prepended) */ - public Builder prependToken(String prependToken){ + public Builder prependToken(String prependToken) { this.prependToken = prependToken; return this; } - public BertIterator build(){ + public BertIterator build() { Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed"); Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required"); Preconditions.checkState(vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map) to set"); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java index 90879f858..d4be5e352 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java @@ -26,7 +26,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.primitives.Pair; import org.nd4j.resources.Resources; @@ -34,10 +33,7 @@ import java.io.File; import java.io.IOException; import 
java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Random; +import java.util.*; import static org.junit.Assert.*; @@ -54,6 +50,9 @@ public class TestBertIterator extends BaseDL4JTest { String toTokenize1 = "I saw a girl with a telescope."; String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + List forInference = new ArrayList<>(); + forInference.add(toTokenize1); + forInference.add(toTokenize2); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); BertIterator b = BertIterator.builder() @@ -100,12 +99,15 @@ public class TestBertIterator extends BaseDL4JTest { assertEquals(expF, mds.getFeatures(0)); assertEquals(expM, mds.getFeaturesMaskArray(0)); + assertEquals(expF,b.featurizeSentences(forInference).getFirst()[0]); + assertEquals(expM,b.featurizeSentences(forInference).getSecond()[0]); assertFalse(b.hasNext()); b.reset(); assertTrue(b.hasNext()); MultiDataSet mds2 = b.next(); + forInference.set(0,toTokenize2); //Same thing, but with segment ID also b = BertIterator.builder() .tokenizer(t) @@ -118,9 +120,11 @@ public class TestBertIterator extends BaseDL4JTest { .build(); mds = b.next(); assertEquals(2, mds.getFeatures().length); + assertEquals(2,b.featurizeSentences(forInference).getFirst().length); //Segment ID should be all 0s for single segment task INDArray segmentId = expM.like(); assertEquals(segmentId, mds.getFeatures(1)); + assertEquals(segmentId,b.featurizeSentences(forInference).getFirst()[1]); } @Test(timeout = 20000L) @@ -157,6 +161,9 @@ public class TestBertIterator extends BaseDL4JTest { public void testLengthHandling() throws Exception { String toTokenize1 = "I saw a girl with a telescope."; String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + List forInference = new ArrayList<>(); + forInference.add(toTokenize1); + forInference.add(toTokenize2); 
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); @@ -205,6 +212,8 @@ public class TestBertIterator extends BaseDL4JTest { assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape()); assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeatures(0)); assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeaturesMaskArray(0)); + assertEquals(mds.getFeatures(0),b.featurizeSentences(forInference).getFirst()[0]); + assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]); //Clip only: clip to maximum, but don't pad if less b = BertIterator.builder() @@ -227,6 +236,9 @@ public class TestBertIterator extends BaseDL4JTest { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); String toTokenize1 = "I saw a girl with a telescope."; String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + List forInference = new ArrayList<>(); + forInference.add(toTokenize1); + forInference.add(toTokenize2); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); @@ -288,6 +300,9 @@ public class TestBertIterator extends BaseDL4JTest { assertEquals(expM, mds.getFeaturesMaskArray(0)); assertEquals(expL, mds.getLabels(0)); assertEquals(expLM, mds.getLabelsMaskArray(0)); + + assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]); + assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]); } private static class TestSentenceProvider implements LabeledSentenceProvider {