DL4J BERT iterator: Add option to prepend token (#39)

* Add option to prepend token

Signed-off-by: Alex Black <blacka101@gmail.com>

* Javadoc

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-07-02 18:23:35 +10:00 committed by AlexDBlack
parent ef8cb55b7b
commit 79d1f02ee1
1 changed files with 19 additions and 0 deletions

View File

@ -125,6 +125,7 @@ public class BertIterator implements MultiDataSetIterator {
protected BertSequenceMasker masker = null;
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
protected String maskToken;
protected String prependToken;
protected List<String> vocabKeysAsList;
@ -143,6 +144,7 @@ public class BertIterator implements MultiDataSetIterator {
this.masker = b.masker;
this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
this.maskToken = b.maskToken;
this.prependToken = b.prependToken;
}
@Override
@ -329,6 +331,9 @@ public class BertIterator implements MultiDataSetIterator {
Tokenizer t = tokenizerFactory.create(sentence);
List<String> tokens = new ArrayList<>();
if(prependToken != null)
tokens.add(prependToken);
while (t.hasMoreTokens()) {
String token = t.nextToken();
tokens.add(token);
@ -372,6 +377,7 @@ public class BertIterator implements MultiDataSetIterator {
protected BertSequenceMasker masker = new BertMaskedLMMasker();
protected UnsupervisedLabelFormat unsupervisedLabelFormat;
protected String maskToken;
protected String prependToken;
/**
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
@ -496,6 +502,19 @@ public class BertIterator implements MultiDataSetIterator {
return this;
}
/**
* Prepend the specified token to the sequences, when doing supervised training.<br>
* i.e., any token sequences will have this added at the start.<br>
* Some BERT/Transformer models may need this - for example sequences starting with a "[CLS]" token.<br>
* No token is prepended by default.
*
* @param prependToken The token to start each sequence with (null: no token will be prepended)
*/
public Builder prependToken(String prependToken){
this.prependToken = prependToken;
return this;
}
public BertIterator build(){
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");