BERT tokenization fixes (#35)

* Add composite token preprocessor

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix case issue with bert tokenization

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-06-28 21:53:05 +10:00 committed by AlexDBlack
parent cc6063402e
commit c28372cb49
3 changed files with 103 additions and 9 deletions

View File

@ -5,6 +5,7 @@ import it.unimi.dsi.fastutil.ints.IntSet;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import java.text.Normalizer;
import java.util.List;
import java.util.Map;
/**
@ -70,6 +71,11 @@ public class BertWordPiecePreProcessor implements TokenPreProcess {
if(cp == 0 || cp == REPLACEMENT_CHAR || isControlCharacter(cp) || (stripAccents && Character.getType(cp) == Character.NON_SPACING_MARK))
continue;
//Convert to lower case if necessary
if(lowerCase){
cp = Character.toLowerCase(cp);
}
//Replace whitespace chars with space
if(isWhiteSpace(cp)) {
sb.append(' ');
@ -89,11 +95,6 @@ public class BertWordPiecePreProcessor implements TokenPreProcess {
continue;
}
//Convert to lower case if necessary
if(lowerCase){
cp = Character.toLowerCase(cp);
}
//All other characters - keep
sb.appendCodePoint(cp);
}
@ -129,4 +130,27 @@ public class BertWordPiecePreProcessor implements TokenPreProcess {
(cp >= 0xF900 && cp <= 0xFAFF) ||
(cp >= 0x2F800 && cp <= 0x2FA1F);
}
/**
* Reconstruct the String from tokens
* @param tokens
* @return
*/
public static String reconstructFromTokens(List<String> tokens){
StringBuilder sb = new StringBuilder();
boolean first = true;
for(String s : tokens){
if(s.startsWith("##")){
sb.append(s.substring(2));
} else {
if(!first && !".".equals(s))
sb.append(" ");
sb.append(s);
first = false;
// }
}
}
return sb.toString();
}
}

View File

@ -0,0 +1,38 @@
package org.deeplearning4j.text.tokenization.tokenizer.preprocessor;
import lombok.NonNull;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.nd4j.base.Preconditions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
/**
* CompositePreProcessor is a {@link TokenPreProcess} that applies multiple preprocessors sequentially
* @author Alex Black
*/
public class CompositePreProcessor implements TokenPreProcess {
private List<TokenPreProcess> preProcessors;
public CompositePreProcessor(@NonNull TokenPreProcess... preProcessors){
Preconditions.checkState(preProcessors.length > 0, "No preprocessors were specified (empty input)");
this.preProcessors = Arrays.asList(preProcessors);
}
public CompositePreProcessor(@NonNull Collection<? extends TokenPreProcess> preProcessors){
Preconditions.checkState(!preProcessors.isEmpty(), "No preprocessors were specified (empty input)");
this.preProcessors = new ArrayList<>(preProcessors);
}
@Override
public String preProcess(String token) {
String s = token;
for(TokenPreProcess tpp : preProcessors){
s = tpp.preProcess(s);
}
return s;
}
}

View File

@ -19,15 +19,13 @@ package org.deeplearning4j.text.tokenization.tokenizer;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.LowCasePreProcessor;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.BertWordPiecePreProcessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayInputStream;
import java.io.File;
@ -61,6 +59,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
log.info("Position: [" + position + "], token1: '" + tok1 + "', token 2: '" + tok2 + "'");
position++;
assertEquals(tok1, tok2);
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
assertEquals(toTokenize, s2);
}
}
@ -76,7 +77,6 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
}
@Test
@Ignore("AB 2019/05/24 - Disabled until dev branch merged - see issue #7657")
public void testBertWordPieceTokenizer3() throws Exception {
String toTokenize = "Donaudampfschifffahrtskapitänsmützeninnenfuttersaum";
TokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
@ -86,6 +86,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
final List<String> expected = Arrays.asList("Donau", "##dam", "##pf", "##schiff", "##fahrt", "##skap", "##itä", "##ns", "##m", "##ützen", "##innen", "##fu", "##tter", "##sa", "##um");
assertEquals(expected, tokenizer.getTokens());
assertEquals(expected, tokenizer2.getTokens());
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
assertEquals(toTokenize, s2);
}
@Test
@ -98,6 +101,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
final List<String> expected = Arrays.asList("I", "saw", "a", "girl", "with", "a", "tele", "##scope", ".");
assertEquals(expected, tokenizer.getTokens());
assertEquals(expected, tokenizer2.getTokens());
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
assertEquals(toTokenize, s2);
}
@Test
@ -112,6 +118,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
final List<String> expected = Arrays.asList("Donau", "##dam", "##pf", "##schiff", "##fahrt", "##s", "Kapitän", "##sm", "##ützen", "##innen", "##fu", "##tter", "##sa", "##um");
assertEquals(expected, tokenizer.getTokens());
assertEquals(expected, tokenizer2.getTokens());
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
assertEquals(toTokenize, s2);
}
@Test
@ -125,6 +134,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
final List<String> expected = Arrays.asList("i", "saw", "a", "girl", "with", "a", "tele", "##scope", ".");
assertEquals(expected, tokenizer.getTokens());
assertEquals(expected, tokenizer2.getTokens());
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
assertEquals(toTokenize.toLowerCase(), s2);
}
@Test
@ -138,6 +150,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
final List<String> expected = Arrays.asList("i", "saw", "a", "girl", "with", "a", "tele", "##scope", ".");
assertEquals(expected, tokenizer.getTokens());
assertEquals(expected, tokenizer2.getTokens());
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
assertEquals(toTokenize.toLowerCase(), s2);
}
@Test
@ -188,4 +203,21 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
}
}
}
@Test
public void testBertWordPieceTokenizer10() throws Exception {
File f = Resources.asFile("deeplearning4j-nlp/bert/uncased_L-12_H-768_A-12/vocab.txt");
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(f, true, true, StandardCharsets.UTF_8);
String s = "This is a sentence with Multiple Cases For Words. It should be coverted to Lower Case here.";
Tokenizer tokenizer = t.create(s);
List<String> list = tokenizer.getTokens();
System.out.println(list);
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(list);
String exp = s.toLowerCase();
assertEquals(exp, s2);
}
}