DL4J BERT iterator: Add option to prepend token (#39)
* Add option to prepend token
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Javadoc
  Signed-off-by: Alex Black <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									ef8cb55b7b
								
							
						
					
					
						commit
						79d1f02ee1
					
				@ -125,6 +125,7 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
    protected BertSequenceMasker masker = null;
 | 
			
		||||
    protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
 | 
			
		||||
    protected String maskToken;
 | 
			
		||||
    protected String prependToken;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    protected List<String> vocabKeysAsList;
 | 
			
		||||
@ -143,6 +144,7 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
        this.masker = b.masker;
 | 
			
		||||
        this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
 | 
			
		||||
        this.maskToken = b.maskToken;
 | 
			
		||||
        this.prependToken = b.prependToken;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
@ -329,6 +331,9 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
        Tokenizer t = tokenizerFactory.create(sentence);
 | 
			
		||||
 | 
			
		||||
        List<String> tokens = new ArrayList<>();
 | 
			
		||||
        if(prependToken != null)
 | 
			
		||||
            tokens.add(prependToken);
 | 
			
		||||
 | 
			
		||||
        while (t.hasMoreTokens()) {
 | 
			
		||||
            String token = t.nextToken();
 | 
			
		||||
            tokens.add(token);
 | 
			
		||||
@ -372,6 +377,7 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
        protected BertSequenceMasker masker = new BertMaskedLMMasker();
 | 
			
		||||
        protected UnsupervisedLabelFormat unsupervisedLabelFormat;
 | 
			
		||||
        protected String maskToken;
 | 
			
		||||
        protected String prependToken;
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
 | 
			
		||||
@ -496,6 +502,19 @@ public class BertIterator implements MultiDataSetIterator {
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Prepend the specified token to the sequences, when doing supervised training.<br>
 | 
			
		||||
         * i.e., any token sequences will have this added at the start.<br>
 | 
			
		||||
         * Some BERT/Transformer models may need this - for example sequences starting with a "[CLS]" token.<br>
 | 
			
		||||
         * No token is prepended by default.
 | 
			
		||||
         *
 | 
			
		||||
         * @param prependToken The token to start each sequence with (null: no token will be prepended)
         * @return This builder instance, so that configuration calls can be chained
 | 
			
		||||
         */
 | 
			
		||||
        public Builder prependToken(String prependToken){
 | 
			
		||||
            this.prependToken = prependToken;
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        public BertIterator build(){
 | 
			
		||||
            Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
 | 
			
		||||
            Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user