From ebeeb8bc4893bff1ac3c21065be4a8cd404c7252 Mon Sep 17 00:00:00 2001
From: Eduardo Gonzalez
Date: Mon, 10 Feb 2020 12:33:04 +0900
Subject: [PATCH] Fix BERT word piece tokenizer stack overflow error (#205)

* Change the regular expression for the Bert tokenizer.

The previous regular expression causes StackOverflowErrors if given a
document with a large amount of whitespace. I believe the one I've
provided is equivalent.

* Add test for new BertWordPieceTokenizer RegEx.

This test should cause a StackOverflowError with the previous version.

* Fix off-by-one in assert.
---
 .../tokenizer/BertWordPieceTokenizer.java  |  2 +-
 .../BertWordPieceTokenizerTests.java       | 19 +++++++++++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizer.java
index 0f9c3ec93..817f8c563 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizer.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizer.java
@@ -29,7 +29,7 @@ import java.util.regex.Pattern;
  */
 @Slf4j
 public class BertWordPieceTokenizer implements Tokenizer {
-    public static final Pattern splitPattern = Pattern.compile("(\\p{javaWhitespace}|((?<=\\p{Punct})|(?=\\p{Punct})))+");
+    public static final Pattern splitPattern = Pattern.compile("\\p{javaWhitespace}+|((?<=\\p{Punct})+|(?=\\p{Punct}+))");
 
     private final List<String> tokens;
     private final TokenPreProcess preTokenizePreProcessor;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java
index 80570ae54..a225230af 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java
@@ -220,4 +220,23 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
         String exp = s.toLowerCase();
         assertEquals(exp, s2);
     }
+
+    @Test
+    public void testTokenizerHandlesLargeContiguousWhitespace() throws Exception {
+        StringBuilder sb = new StringBuilder();
+        sb.append("apple.");
+        for (int i = 0; i < 10000; i++) {
+            sb.append(" ");
+        }
+        sb.append(".pen. .pineapple");
+
+        File f = Resources.asFile("deeplearning4j-nlp/bert/uncased_L-12_H-768_A-12/vocab.txt");
+        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(f, true, true, StandardCharsets.UTF_8);
+
+        Tokenizer tokenizer = t.create(sb.toString());
+        List<String> list = tokenizer.getTokens();
+        System.out.println(list);
+
+        assertEquals(8, list.size());
+    }
 }
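
Note (not part of the patch): a minimal standalone sketch of the failure mode being fixed. In java.util.regex, a "+" applied to an alternation group consumes stack proportional to the number of repetitions, while "+" applied to a single character class like \p{javaWhitespace} is matched iteratively, which is why the new pattern survives long whitespace runs. The class name SplitPatternRepro is made up for illustration, and the StackOverflowError assumes a default JVM thread stack size.

import java.util.regex.Pattern;

public class SplitPatternRepro {
    public static void main(String[] args) {
        // Old pattern: the outer "+" repeats an alternation group; the
        // regex engine recurses once per repetition while matching, so a
        // long whitespace run exhausts the stack.
        Pattern oldPattern = Pattern.compile(
                "(\\p{javaWhitespace}|((?<=\\p{Punct})|(?=\\p{Punct})))+");
        // New pattern: "\\p{javaWhitespace}+" quantifies a plain character
        // class, which the engine matches with an iterative loop instead.
        Pattern newPattern = Pattern.compile(
                "\\p{javaWhitespace}+|((?<=\\p{Punct})+|(?=\\p{Punct}+))");

        // Build ~10k contiguous spaces, mirroring the new test above.
        StringBuilder sb = new StringBuilder("apple.");
        for (int i = 0; i < 10000; i++) {
            sb.append(' ');
        }
        sb.append(".pen");
        String input = sb.toString();

        try {
            oldPattern.split(input); // expected to throw on a default stack
        } catch (StackOverflowError e) {
            System.out.println("old pattern: StackOverflowError");
        }
        // The new pattern splits the same input without deep recursion.
        System.out.println(newPattern.split(input).length + " tokens with new pattern");
    }
}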