DL4J BERT iterator: Add option to prepend token (#39)
* Add option to prepend token Signed-off-by: Alex Black <blacka101@gmail.com> * Javadoc Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
ef8cb55b7b
commit
79d1f02ee1
|
@ -125,6 +125,7 @@ public class BertIterator implements MultiDataSetIterator {
|
|||
protected BertSequenceMasker masker = null;
|
||||
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
|
||||
protected String maskToken;
|
||||
protected String prependToken;
|
||||
|
||||
|
||||
protected List<String> vocabKeysAsList;
|
||||
|
@ -143,6 +144,7 @@ public class BertIterator implements MultiDataSetIterator {
|
|||
this.masker = b.masker;
|
||||
this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
|
||||
this.maskToken = b.maskToken;
|
||||
this.prependToken = b.prependToken;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -329,6 +331,9 @@ public class BertIterator implements MultiDataSetIterator {
|
|||
Tokenizer t = tokenizerFactory.create(sentence);
|
||||
|
||||
List<String> tokens = new ArrayList<>();
|
||||
if(prependToken != null)
|
||||
tokens.add(prependToken);
|
||||
|
||||
while (t.hasMoreTokens()) {
|
||||
String token = t.nextToken();
|
||||
tokens.add(token);
|
||||
|
@ -372,6 +377,7 @@ public class BertIterator implements MultiDataSetIterator {
|
|||
protected BertSequenceMasker masker = new BertMaskedLMMasker();
|
||||
protected UnsupervisedLabelFormat unsupervisedLabelFormat;
|
||||
protected String maskToken;
|
||||
protected String prependToken;
|
||||
|
||||
/**
|
||||
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
|
||||
|
@ -496,6 +502,19 @@ public class BertIterator implements MultiDataSetIterator {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepend the specified token to the sequences, when doing supervised training.<br>
|
||||
* i.e., any token sequences will have this added at the start.<br>
|
||||
* Some BERT/Transformer models may need this - for example sequences starting with a "[CLS]" token.<br>
|
||||
* No token is prepended by default.
|
||||
*
|
||||
* @param prependToken The token to start each sequence with (null: no token will be prepended)
|
||||
*/
|
||||
public Builder prependToken(String prependToken){
|
||||
this.prependToken = prependToken;
|
||||
return this;
|
||||
}
|
||||
|
||||
public BertIterator build(){
|
||||
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
|
||||
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
|
||||
|
|
Loading…
Reference in New Issue