diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java index afa146c18..c6a88ffb3 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java @@ -125,6 +125,7 @@ public class BertIterator implements MultiDataSetIterator { protected BertSequenceMasker masker = null; protected UnsupervisedLabelFormat unsupervisedLabelFormat = null; protected String maskToken; + protected String prependToken; protected List vocabKeysAsList; @@ -143,6 +144,7 @@ public class BertIterator implements MultiDataSetIterator { this.masker = b.masker; this.unsupervisedLabelFormat = b.unsupervisedLabelFormat; this.maskToken = b.maskToken; + this.prependToken = b.prependToken; } @Override @@ -329,6 +331,9 @@ public class BertIterator implements MultiDataSetIterator { Tokenizer t = tokenizerFactory.create(sentence); List tokens = new ArrayList<>(); + if(prependToken != null) + tokens.add(prependToken); + while (t.hasMoreTokens()) { String token = t.nextToken(); tokens.add(token); @@ -372,6 +377,7 @@ public class BertIterator implements MultiDataSetIterator { protected BertSequenceMasker masker = new BertMaskedLMMasker(); protected UnsupervisedLabelFormat unsupervisedLabelFormat; protected String maskToken; + protected String prependToken; /** * Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details. @@ -496,6 +502,19 @@ public class BertIterator implements MultiDataSetIterator { return this; } + /** + * Prepend the specified token to the sequences, when doing supervised training.
+ * i.e., any token sequences will have this added at the start.
+ * Some BERT/Transformer models may need this - for example sequences starting with a "[CLS]" token.
+ * No token is prepended by default. + * + * @param prependToken The token to start each sequence with (null: no token will be prepended) + */ + public Builder prependToken(String prependToken){ + this.prependToken = prependToken; + return this; + } + public BertIterator build(){ Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed"); Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");