DL4J BERT iterator: Add option to prepend token (#39)
* Add option to prepend token
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Javadoc
  Signed-off-by: Alex Black <blacka101@gmail.com>
master
parent
ef8cb55b7b
commit
79d1f02ee1
|
@ -125,6 +125,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
protected BertSequenceMasker masker = null;
|
protected BertSequenceMasker masker = null;
|
||||||
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
|
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
|
||||||
protected String maskToken;
|
protected String maskToken;
|
||||||
|
protected String prependToken;
|
||||||
|
|
||||||
|
|
||||||
protected List<String> vocabKeysAsList;
|
protected List<String> vocabKeysAsList;
|
||||||
|
@ -143,6 +144,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
this.masker = b.masker;
|
this.masker = b.masker;
|
||||||
this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
|
this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
|
||||||
this.maskToken = b.maskToken;
|
this.maskToken = b.maskToken;
|
||||||
|
this.prependToken = b.prependToken;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -329,6 +331,9 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
Tokenizer t = tokenizerFactory.create(sentence);
|
Tokenizer t = tokenizerFactory.create(sentence);
|
||||||
|
|
||||||
List<String> tokens = new ArrayList<>();
|
List<String> tokens = new ArrayList<>();
|
||||||
|
if(prependToken != null)
|
||||||
|
tokens.add(prependToken);
|
||||||
|
|
||||||
while (t.hasMoreTokens()) {
|
while (t.hasMoreTokens()) {
|
||||||
String token = t.nextToken();
|
String token = t.nextToken();
|
||||||
tokens.add(token);
|
tokens.add(token);
|
||||||
|
@ -372,6 +377,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
protected BertSequenceMasker masker = new BertMaskedLMMasker();
|
protected BertSequenceMasker masker = new BertMaskedLMMasker();
|
||||||
protected UnsupervisedLabelFormat unsupervisedLabelFormat;
|
protected UnsupervisedLabelFormat unsupervisedLabelFormat;
|
||||||
protected String maskToken;
|
protected String maskToken;
|
||||||
|
protected String prependToken;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
|
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
|
||||||
|
@ -496,6 +502,19 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Prepend the specified token to the sequences, when doing supervised training.<br>
|
||||||
|
* i.e., any token sequences will have this added at the start.<br>
|
||||||
|
* Some BERT/Transformer models may need this - for example sequences starting with a "[CLS]" token.<br>
|
||||||
|
* No token is prepended by default.
|
||||||
|
*
|
||||||
|
* @param prependToken The token to start each sequence with (null: no token will be prepended)
|
||||||
|
*/
|
||||||
|
public Builder prependToken(String prependToken){
|
||||||
|
this.prependToken = prependToken;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public BertIterator build(){
|
public BertIterator build(){
|
||||||
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
|
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
|
||||||
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
|
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
|
||||||
|
|
Loading…
Reference in New Issue