From c28372cb490a80a613e7258a977b63a5488f494c Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Fri, 28 Jun 2019 21:53:05 +1000
Subject: [PATCH] BERT tokenization fixes (#35)

* Add composite token preprocessor

Signed-off-by: AlexDBlack

* Fix case issue with bert tokenization

Signed-off-by: AlexDBlack
---
 .../BertWordPiecePreProcessor.java            | 34 +++++++++++++---
 .../preprocessor/CompositePreProcessor.java   | 38 ++++++++++++++++++
 .../BertWordPieceTokenizerTests.java          | 40 +++++++++++++++++--
 3 files changed, 103 insertions(+), 9 deletions(-)
 create mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/CompositePreProcessor.java

diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/BertWordPiecePreProcessor.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/BertWordPiecePreProcessor.java
index 14b5775ad..b46ab164a 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/BertWordPiecePreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/BertWordPiecePreProcessor.java
@@ -5,6 +5,7 @@ import it.unimi.dsi.fastutil.ints.IntSet;
 import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
 
 import java.text.Normalizer;
+import java.util.List;
 import java.util.Map;
 
 /**
@@ -70,6 +71,11 @@ public class BertWordPiecePreProcessor implements TokenPreProcess {
             if(cp == 0 || cp == REPLACEMENT_CHAR || isControlCharacter(cp) || (stripAccents && Character.getType(cp) == Character.NON_SPACING_MARK))
                 continue;
 
+            //Convert to lower case if necessary
+            if(lowerCase){
+                cp = Character.toLowerCase(cp);
+            }
+
             //Replace whitespace chars with space
             if(isWhiteSpace(cp)) {
                 sb.append(' ');
@@ -89,11 +95,6 @@ public class BertWordPiecePreProcessor implements TokenPreProcess {
                 continue;
             }
 
-            //Convert to lower case if necessary
-            if(lowerCase){
-                cp = Character.toLowerCase(cp);
-            }
-
             //All other characters - keep
             sb.appendCodePoint(cp);
         }
@@ -129,4 +130,27 @@ public class BertWordPiecePreProcessor implements TokenPreProcess {
                 (cp >= 0xF900 && cp <= 0xFAFF) ||
                 (cp >= 0x2F800 && cp <= 0x2FA1F);
     }
+
+
+    /**
+     * Reconstruct the String from tokens
+     * @param tokens
+     * @return
+     */
+    public static String reconstructFromTokens(List<String> tokens){
+        StringBuilder sb = new StringBuilder();
+        boolean first = true;
+        for(String s : tokens){
+            if(s.startsWith("##")){
+                sb.append(s.substring(2));
+            } else {
+                if(!first && !".".equals(s))
+                    sb.append(" ");
+                sb.append(s);
+                first = false;
+//            }
+            }
+        }
+        return sb.toString();
+    }
 }
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/CompositePreProcessor.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/CompositePreProcessor.java
new file mode 100644
index 000000000..70dbd78fb
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/CompositePreProcessor.java
@@ -0,0 +1,38 @@
+package org.deeplearning4j.text.tokenization.tokenizer.preprocessor;
+
+import lombok.NonNull;
+import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
+import org.nd4j.base.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * CompositePreProcessor is a {@link TokenPreProcess} that applies multiple preprocessors sequentially
+ * @author Alex Black
+ */
+public class CompositePreProcessor implements TokenPreProcess {
+
+    private List<TokenPreProcess> preProcessors;
+
+    public CompositePreProcessor(@NonNull TokenPreProcess... preProcessors){
+        Preconditions.checkState(preProcessors.length > 0, "No preprocessors were specified (empty input)");
+        this.preProcessors = Arrays.asList(preProcessors);
+    }
+
+    public CompositePreProcessor(@NonNull Collection<? extends TokenPreProcess> preProcessors){
+        Preconditions.checkState(!preProcessors.isEmpty(), "No preprocessors were specified (empty input)");
+        this.preProcessors = new ArrayList<>(preProcessors);
+    }
+
+    @Override
+    public String preProcess(String token) {
+        String s = token;
+        for(TokenPreProcess tpp : preProcessors){
+            s = tpp.preProcess(s);
+        }
+        return s;
+    }
+}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java
index e531f9a0a..4b78e51a2 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java
@@ -19,15 +19,13 @@ package org.deeplearning4j.text.tokenization.tokenizer;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.io.FileUtils;
 import org.deeplearning4j.BaseDL4JTest;
-import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.LowCasePreProcessor;
+import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.BertWordPiecePreProcessor;
 import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
 import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
 import org.junit.Ignore;
 import org.junit.Test;
 import org.nd4j.linalg.io.ClassPathResource;
 import org.nd4j.resources.Resources;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 import java.io.ByteArrayInputStream;
 import java.io.File;
@@ -61,6 +59,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
             log.info("Position: [" + position + "], token1: '" + tok1 + "', token 2: '" + tok2 + "'");
             position++;
             assertEquals(tok1, tok2);
+
+            String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
+            assertEquals(toTokenize, s2);
         }
     }
 
@@ -76,7 +77,6 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
     }
 
     @Test
-    @Ignore("AB 2019/05/24 - Disabled until dev branch merged - see issue #7657")
     public void testBertWordPieceTokenizer3() throws Exception {
         String toTokenize = "Donaudampfschifffahrtskapitänsmützeninnenfuttersaum";
         TokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
@@ -86,6 +86,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
         final List<String> expected = Arrays.asList("Donau", "##dam", "##pf", "##schiff", "##fahrt", "##skap", "##itä", "##ns", "##m", "##ützen", "##innen", "##fu", "##tter", "##sa", "##um");
         assertEquals(expected, tokenizer.getTokens());
         assertEquals(expected, tokenizer2.getTokens());
+
+        String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
+        assertEquals(toTokenize, s2);
     }
 
     @Test
@@ -98,6 +101,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
         final List<String> expected = Arrays.asList("I", "saw", "a", "girl", "with", "a", "tele", "##scope", ".");
         assertEquals(expected, tokenizer.getTokens());
         assertEquals(expected, tokenizer2.getTokens());
+
+        String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
+        assertEquals(toTokenize, s2);
     }
 
     @Test
@@ -112,6 +118,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
         final List<String> expected = Arrays.asList("Donau", "##dam", "##pf", "##schiff", "##fahrt", "##s", "Kapitän", "##sm", "##ützen", "##innen", "##fu", "##tter", "##sa", "##um");
         assertEquals(expected, tokenizer.getTokens());
         assertEquals(expected, tokenizer2.getTokens());
+
+        String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
+        assertEquals(toTokenize, s2);
     }
 
     @Test
@@ -125,6 +134,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
         final List<String> expected = Arrays.asList("i", "saw", "a", "girl", "with", "a", "tele", "##scope", ".");
         assertEquals(expected, tokenizer.getTokens());
         assertEquals(expected, tokenizer2.getTokens());
+
+        String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
+        assertEquals(toTokenize.toLowerCase(), s2);
     }
 
     @Test
@@ -138,6 +150,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
         final List<String> expected = Arrays.asList("i", "saw", "a", "girl", "with", "a", "tele", "##scope", ".");
         assertEquals(expected, tokenizer.getTokens());
         assertEquals(expected, tokenizer2.getTokens());
+
+        String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
+        assertEquals(toTokenize.toLowerCase(), s2);
     }
 
     @Test
@@ -188,4 +203,21 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
             }
         }
     }
+
+
+    @Test
+    public void testBertWordPieceTokenizer10() throws Exception {
+        File f = Resources.asFile("deeplearning4j-nlp/bert/uncased_L-12_H-768_A-12/vocab.txt");
+        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(f, true, true, StandardCharsets.UTF_8);
+
+        String s = "This is a sentence with Multiple Cases For Words. It should be converted to Lower Case here.";
+
+        Tokenizer tokenizer = t.create(s);
+        List<String> list = tokenizer.getTokens();
+        System.out.println(list);
+
+        String s2 = BertWordPiecePreProcessor.reconstructFromTokens(list);
+        String exp = s.toLowerCase();
+        assertEquals(exp, s2);
+    }
 }