Fix BERT word piece tokenizer stack overflow error (#205)
* Change the regular expression for the BERT tokenizer. The previous regular expression causes StackOverflowErrors if given a document with a large amount of whitespace. I believe that the one I've provided is equivalent.
* Add a test for the new BertWordPieceTokenizer regex. This test causes a StackOverflowError with the previous version.
* Fix an off-by-one in the assert.
parent 8a0d5e3b97
commit ebeeb8bc48
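To make the failure mode concrete, below is a minimal standalone repro sketch. It is not part of the commit; the class name SplitPatternRepro is made up for illustration, and whether the overflow actually triggers depends on the JVM's thread stack size (-Xss). java.util.regex evaluates group repetition recursively, so the old pattern's outer "+" pushes roughly one stack frame per whitespace character, while the new pattern consumes a whole whitespace run with a single greedy "+".

import java.util.regex.Pattern;

// Hypothetical repro sketch, not part of this commit.
public class SplitPatternRepro {

    // Old pattern: the outer "+" repeats a group (which can also match at
    // empty lookaround positions) once per whitespace character.
    static final Pattern OLD_PATTERN =
            Pattern.compile("(\\p{javaWhitespace}|((?<=\\p{Punct})|(?=\\p{Punct})))+");

    // New pattern: a whitespace run is consumed by one greedy "+", so long
    // runs no longer grow the matcher's recursion depth.
    static final Pattern NEW_PATTERN =
            Pattern.compile("\\p{javaWhitespace}+|((?<=\\p{Punct})+|(?=\\p{Punct}+))");

    public static void main(String[] args) {
        // Same shape of input as the test added in this commit:
        // a 10,000-space run between punctuation-adjacent words.
        StringBuilder sb = new StringBuilder("apple.");
        for (int i = 0; i < 10_000; i++) {
            sb.append(' ');
        }
        sb.append(".pen. .pineapple");
        String doc = sb.toString();

        // Completes normally with the new pattern.
        System.out.println("new pattern: " + NEW_PATTERN.split(doc).length + " pieces");

        try {
            OLD_PATTERN.split(doc);
            System.out.println("old pattern: no overflow (stack was large enough)");
        } catch (StackOverflowError e) {
            System.out.println("old pattern: StackOverflowError, as described above");
        }
    }
}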
@@ -29,7 +29,7 @@ import java.util.regex.Pattern;
  */
 @Slf4j
 public class BertWordPieceTokenizer implements Tokenizer {
-    public static final Pattern splitPattern = Pattern.compile("(\\p{javaWhitespace}|((?<=\\p{Punct})|(?=\\p{Punct})))+");
+    public static final Pattern splitPattern = Pattern.compile("\\p{javaWhitespace}+|((?<=\\p{Punct})+|(?=\\p{Punct}+))");
 
     private final List<String> tokens;
     private final TokenPreProcess preTokenizePreProcessor;
@@ -220,4 +220,23 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
         String exp = s.toLowerCase();
         assertEquals(exp, s2);
     }
+
+    @Test
+    public void testTokenizerHandlesLargeContiguousWhitespace() throws Exception {
+        StringBuilder sb = new StringBuilder();
+        sb.append("apple.");
+        for (int i = 0; i < 10000; i++) {
+            sb.append(" ");
+        }
+        sb.append(".pen. .pineapple");
+
+        File f = Resources.asFile("deeplearning4j-nlp/bert/uncased_L-12_H-768_A-12/vocab.txt");
+        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(f, true, true, StandardCharsets.UTF_8);
+
+        Tokenizer tokenizer = t.create(sb.toString());
+        List<String> list = tokenizer.getTokens();
+        System.out.println(list);
+
+        assertEquals(8, list.size());
+    }
 }