BERT tokenization fixes (#35)
* Add composite token preprocessor Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix case issue with bert tokenization Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
cc6063402e
commit
c28372cb49
|
@ -5,6 +5,7 @@ import it.unimi.dsi.fastutil.ints.IntSet;
|
|||
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
|
||||
|
||||
import java.text.Normalizer;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
|
@ -70,6 +71,11 @@ public class BertWordPiecePreProcessor implements TokenPreProcess {
|
|||
if(cp == 0 || cp == REPLACEMENT_CHAR || isControlCharacter(cp) || (stripAccents && Character.getType(cp) == Character.NON_SPACING_MARK))
|
||||
continue;
|
||||
|
||||
//Convert to lower case if necessary
|
||||
if(lowerCase){
|
||||
cp = Character.toLowerCase(cp);
|
||||
}
|
||||
|
||||
//Replace whitespace chars with space
|
||||
if(isWhiteSpace(cp)) {
|
||||
sb.append(' ');
|
||||
|
@ -89,11 +95,6 @@ public class BertWordPiecePreProcessor implements TokenPreProcess {
|
|||
continue;
|
||||
}
|
||||
|
||||
//Convert to lower case if necessary
|
||||
if(lowerCase){
|
||||
cp = Character.toLowerCase(cp);
|
||||
}
|
||||
|
||||
//All other characters - keep
|
||||
sb.appendCodePoint(cp);
|
||||
}
|
||||
|
@ -129,4 +130,27 @@ public class BertWordPiecePreProcessor implements TokenPreProcess {
|
|||
(cp >= 0xF900 && cp <= 0xFAFF) ||
|
||||
(cp >= 0x2F800 && cp <= 0x2FA1F);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Reconstruct the String from tokens
|
||||
* @param tokens
|
||||
* @return
|
||||
*/
|
||||
public static String reconstructFromTokens(List<String> tokens){
|
||||
StringBuilder sb = new StringBuilder();
|
||||
boolean first = true;
|
||||
for(String s : tokens){
|
||||
if(s.startsWith("##")){
|
||||
sb.append(s.substring(2));
|
||||
} else {
|
||||
if(!first && !".".equals(s))
|
||||
sb.append(" ");
|
||||
sb.append(s);
|
||||
first = false;
|
||||
// }
|
||||
}
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
package org.deeplearning4j.text.tokenization.tokenizer.preprocessor;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* CompositePreProcessor is a {@link TokenPreProcess} that applies multiple preprocessors sequentially
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class CompositePreProcessor implements TokenPreProcess {
|
||||
|
||||
private List<TokenPreProcess> preProcessors;
|
||||
|
||||
public CompositePreProcessor(@NonNull TokenPreProcess... preProcessors){
|
||||
Preconditions.checkState(preProcessors.length > 0, "No preprocessors were specified (empty input)");
|
||||
this.preProcessors = Arrays.asList(preProcessors);
|
||||
}
|
||||
|
||||
public CompositePreProcessor(@NonNull Collection<? extends TokenPreProcess> preProcessors){
|
||||
Preconditions.checkState(!preProcessors.isEmpty(), "No preprocessors were specified (empty input)");
|
||||
this.preProcessors = new ArrayList<>(preProcessors);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String preProcess(String token) {
|
||||
String s = token;
|
||||
for(TokenPreProcess tpp : preProcessors){
|
||||
s = tpp.preProcess(s);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
}
|
|
@ -19,15 +19,13 @@ package org.deeplearning4j.text.tokenization.tokenizer;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.LowCasePreProcessor;
|
||||
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.BertWordPiecePreProcessor;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.resources.Resources;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.File;
|
||||
|
@ -61,6 +59,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
|
|||
log.info("Position: [" + position + "], token1: '" + tok1 + "', token 2: '" + tok2 + "'");
|
||||
position++;
|
||||
assertEquals(tok1, tok2);
|
||||
|
||||
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
|
||||
assertEquals(toTokenize, s2);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -76,7 +77,6 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
@Ignore("AB 2019/05/24 - Disabled until dev branch merged - see issue #7657")
|
||||
public void testBertWordPieceTokenizer3() throws Exception {
|
||||
String toTokenize = "Donaudampfschifffahrtskapitänsmützeninnenfuttersaum";
|
||||
TokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
||||
|
@ -86,6 +86,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
|
|||
final List<String> expected = Arrays.asList("Donau", "##dam", "##pf", "##schiff", "##fahrt", "##skap", "##itä", "##ns", "##m", "##ützen", "##innen", "##fu", "##tter", "##sa", "##um");
|
||||
assertEquals(expected, tokenizer.getTokens());
|
||||
assertEquals(expected, tokenizer2.getTokens());
|
||||
|
||||
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
|
||||
assertEquals(toTokenize, s2);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -98,6 +101,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
|
|||
final List<String> expected = Arrays.asList("I", "saw", "a", "girl", "with", "a", "tele", "##scope", ".");
|
||||
assertEquals(expected, tokenizer.getTokens());
|
||||
assertEquals(expected, tokenizer2.getTokens());
|
||||
|
||||
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
|
||||
assertEquals(toTokenize, s2);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -112,6 +118,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
|
|||
final List<String> expected = Arrays.asList("Donau", "##dam", "##pf", "##schiff", "##fahrt", "##s", "Kapitän", "##sm", "##ützen", "##innen", "##fu", "##tter", "##sa", "##um");
|
||||
assertEquals(expected, tokenizer.getTokens());
|
||||
assertEquals(expected, tokenizer2.getTokens());
|
||||
|
||||
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
|
||||
assertEquals(toTokenize, s2);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -125,6 +134,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
|
|||
final List<String> expected = Arrays.asList("i", "saw", "a", "girl", "with", "a", "tele", "##scope", ".");
|
||||
assertEquals(expected, tokenizer.getTokens());
|
||||
assertEquals(expected, tokenizer2.getTokens());
|
||||
|
||||
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
|
||||
assertEquals(toTokenize.toLowerCase(), s2);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -138,6 +150,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
|
|||
final List<String> expected = Arrays.asList("i", "saw", "a", "girl", "with", "a", "tele", "##scope", ".");
|
||||
assertEquals(expected, tokenizer.getTokens());
|
||||
assertEquals(expected, tokenizer2.getTokens());
|
||||
|
||||
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
|
||||
assertEquals(toTokenize.toLowerCase(), s2);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -188,4 +203,21 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testBertWordPieceTokenizer10() throws Exception {
|
||||
File f = Resources.asFile("deeplearning4j-nlp/bert/uncased_L-12_H-768_A-12/vocab.txt");
|
||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(f, true, true, StandardCharsets.UTF_8);
|
||||
|
||||
String s = "This is a sentence with Multiple Cases For Words. It should be coverted to Lower Case here.";
|
||||
|
||||
Tokenizer tokenizer = t.create(s);
|
||||
List<String> list = tokenizer.getTokens();
|
||||
System.out.println(list);
|
||||
|
||||
String s2 = BertWordPiecePreProcessor.reconstructFromTokens(list);
|
||||
String exp = s.toLowerCase();
|
||||
assertEquals(exp, s2);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue