BERT tokenization fixes (#35)

* Add composite token preprocessor
* Fix case issue with bert tokenization

Signed-off-by: AlexDBlack <blacka101@gmail.com>

branch master
parent cc6063402e
commit c28372cb49
BertWordPiecePreProcessor.java
@@ -5,6 +5,7 @@ import it.unimi.dsi.fastutil.ints.IntSet;
 import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;

 import java.text.Normalizer;
+import java.util.List;
 import java.util.Map;

 /**
@@ -70,6 +71,11 @@ public class BertWordPiecePreProcessor implements TokenPreProcess {
             if(cp == 0 || cp == REPLACEMENT_CHAR || isControlCharacter(cp) || (stripAccents && Character.getType(cp) == Character.NON_SPACING_MARK))
                 continue;

+            //Convert to lower case if necessary
+            if(lowerCase){
+                cp = Character.toLowerCase(cp);
+            }
+
             //Replace whitespace chars with space
             if(isWhiteSpace(cp)) {
                 sb.append(' ');
@@ -89,11 +95,6 @@ public class BertWordPiecePreProcessor implements TokenPreProcess {
                 continue;
             }

-            //Convert to lower case if necessary
-            if(lowerCase){
-                cp = Character.toLowerCase(cp);
-            }
-
             //All other characters - keep
             sb.appendCodePoint(cp);
         }
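Note on the case fix: taken together, the two hunks above move the lowercase conversion so it runs before the whitespace branch rather than only in the final "keep other characters" branch, so every retained code point is case-folded consistently. Below is a minimal self-contained sketch of the corrected loop order; it is a hypothetical simplification (Character.isISOControl and Character.isWhitespace stand in for the class's own isControlCharacter and isWhiteSpace helpers), not the DL4J implementation itself.

public class LowercaseOrderSketch {
    // Hypothetical simplification of the fixed loop: skip control chars,
    // lowercase every remaining code point, then fold whitespace to ' '.
    static String preProcess(String token, boolean lowerCase) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < token.length(); ) {
            int cp = token.codePointAt(i);
            i += Character.charCount(cp);
            if (cp == 0 || Character.isISOControl(cp))
                continue;                       // dropped characters are never emitted
            if (lowerCase)
                cp = Character.toLowerCase(cp); // now runs before the whitespace branch
            if (Character.isWhitespace(cp)) {
                sb.append(' ');                 // normalize all whitespace to a space
                continue;
            }
            sb.appendCodePoint(cp);             // keep everything else
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(preProcess("Hello\tWORLD", true)); // -> "hello world"
    }
}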
@@ -129,4 +130,27 @@ public class BertWordPiecePreProcessor implements TokenPreProcess {
                 (cp >= 0xF900 && cp <= 0xFAFF) ||
                 (cp >= 0x2F800 && cp <= 0x2FA1F);
     }
+
+
+    /**
+     * Reconstruct the String from tokens
+     * @param tokens
+     * @return
+     */
+    public static String reconstructFromTokens(List<String> tokens){
+        StringBuilder sb = new StringBuilder();
+        boolean first = true;
+        for(String s : tokens){
+            if(s.startsWith("##")){
+                sb.append(s.substring(2));
+            } else {
+                if(!first && !".".equals(s))
+                    sb.append(" ");
+                sb.append(s);
+                first = false;
+                // }
+            }
+        }
+        return sb.toString();
+    }
 }
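The new reconstructFromTokens helper inverts WordPiece splitting: pieces prefixed with "##" are glued onto the previous token, a "." is attached without a leading space, and everything else is space-separated. A brief usage sketch follows (ReconstructExample is a hypothetical wrapper class; the token list mirrors the test expectations further down):

import java.util.Arrays;
import java.util.List;

import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.BertWordPiecePreProcessor;

public class ReconstructExample {
    public static void main(String[] args) {
        // "##scope" is a continuation piece, so it is appended to "tele";
        // "." is appended without a preceding space.
        List<String> tokens = Arrays.asList("I", "saw", "a", "girl", "with", "a", "tele", "##scope", ".");
        String s = BertWordPiecePreProcessor.reconstructFromTokens(tokens);
        System.out.println(s); // -> "I saw a girl with a telescope."
    }
}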
CompositePreProcessor.java (new file)
@@ -0,0 +1,38 @@
+package org.deeplearning4j.text.tokenization.tokenizer.preprocessor;
+
+import lombok.NonNull;
+import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
+import org.nd4j.base.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * CompositePreProcessor is a {@link TokenPreProcess} that applies multiple preprocessors sequentially
+ * @author Alex Black
+ */
+public class CompositePreProcessor implements TokenPreProcess {
+
+    private List<TokenPreProcess> preProcessors;
+
+    public CompositePreProcessor(@NonNull TokenPreProcess... preProcessors){
+        Preconditions.checkState(preProcessors.length > 0, "No preprocessors were specified (empty input)");
+        this.preProcessors = Arrays.asList(preProcessors);
+    }
+
+    public CompositePreProcessor(@NonNull Collection<? extends TokenPreProcess> preProcessors){
+        Preconditions.checkState(!preProcessors.isEmpty(), "No preprocessors were specified (empty input)");
+        this.preProcessors = new ArrayList<>(preProcessors);
+    }
+
+    @Override
+    public String preProcess(String token) {
+        String s = token;
+        for(TokenPreProcess tpp : preProcessors){
+            s = tpp.preProcess(s);
+        }
+        return s;
+    }
+}
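A brief usage sketch for the new class: the preprocessors run in constructor order, each receiving the previous one's output. This assumes LowCasePreProcessor from the same package (visible in the test imports below) has its usual no-arg constructor; the trailing lambda compiles because TokenPreProcess declares a single abstract method.

import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CompositePreProcessor;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.LowCasePreProcessor;

public class CompositeExample {
    public static void main(String[] args) {
        // Lowercase first, then trim; order matters for composed preprocessors.
        TokenPreProcess pp = new CompositePreProcessor(
                new LowCasePreProcessor(),
                token -> token.trim());
        System.out.println(pp.preProcess("  Hello World  ")); // -> "hello world"
    }
}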
BertWordPieceTokenizerTests.java
@@ -19,15 +19,13 @@ package org.deeplearning4j.text.tokenization.tokenizer;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.io.FileUtils;
 import org.deeplearning4j.BaseDL4JTest;
-import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.LowCasePreProcessor;
+import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.BertWordPiecePreProcessor;
 import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
 import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
 import org.junit.Ignore;
 import org.junit.Test;
 import org.nd4j.linalg.io.ClassPathResource;
 import org.nd4j.resources.Resources;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;

 import java.io.ByteArrayInputStream;
 import java.io.File;
@@ -61,6 +59,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
             log.info("Position: [" + position + "], token1: '" + tok1 + "', token 2: '" + tok2 + "'");
             position++;
             assertEquals(tok1, tok2);
+
+            String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
+            assertEquals(toTokenize, s2);
         }
     }

@@ -76,7 +77,6 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
     }

     @Test
-    @Ignore("AB 2019/05/24 - Disabled until dev branch merged - see issue #7657")
     public void testBertWordPieceTokenizer3() throws Exception {
         String toTokenize = "Donaudampfschifffahrtskapitänsmützeninnenfuttersaum";
         TokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
@@ -86,6 +86,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
         final List<String> expected = Arrays.asList("Donau", "##dam", "##pf", "##schiff", "##fahrt", "##skap", "##itä", "##ns", "##m", "##ützen", "##innen", "##fu", "##tter", "##sa", "##um");
         assertEquals(expected, tokenizer.getTokens());
         assertEquals(expected, tokenizer2.getTokens());
+
+        String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
+        assertEquals(toTokenize, s2);
     }

     @Test
@@ -98,6 +101,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
         final List<String> expected = Arrays.asList("I", "saw", "a", "girl", "with", "a", "tele", "##scope", ".");
         assertEquals(expected, tokenizer.getTokens());
         assertEquals(expected, tokenizer2.getTokens());
+
+        String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
+        assertEquals(toTokenize, s2);
     }

     @Test
@@ -112,6 +118,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
         final List<String> expected = Arrays.asList("Donau", "##dam", "##pf", "##schiff", "##fahrt", "##s", "Kapitän", "##sm", "##ützen", "##innen", "##fu", "##tter", "##sa", "##um");
         assertEquals(expected, tokenizer.getTokens());
         assertEquals(expected, tokenizer2.getTokens());
+
+        String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
+        assertEquals(toTokenize, s2);
     }

     @Test
@@ -125,6 +134,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
         final List<String> expected = Arrays.asList("i", "saw", "a", "girl", "with", "a", "tele", "##scope", ".");
         assertEquals(expected, tokenizer.getTokens());
         assertEquals(expected, tokenizer2.getTokens());
+
+        String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
+        assertEquals(toTokenize.toLowerCase(), s2);
     }

     @Test
@@ -138,6 +150,9 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
         final List<String> expected = Arrays.asList("i", "saw", "a", "girl", "with", "a", "tele", "##scope", ".");
         assertEquals(expected, tokenizer.getTokens());
         assertEquals(expected, tokenizer2.getTokens());
+
+        String s2 = BertWordPiecePreProcessor.reconstructFromTokens(tokenizer.getTokens());
+        assertEquals(toTokenize.toLowerCase(), s2);
     }

     @Test
@@ -188,4 +203,21 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
             }
         }
     }
+
+
+    @Test
+    public void testBertWordPieceTokenizer10() throws Exception {
+        File f = Resources.asFile("deeplearning4j-nlp/bert/uncased_L-12_H-768_A-12/vocab.txt");
+        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(f, true, true, StandardCharsets.UTF_8);
+
+        String s = "This is a sentence with Multiple Cases For Words. It should be coverted to Lower Case here.";
+
+        Tokenizer tokenizer = t.create(s);
+        List<String> list = tokenizer.getTokens();
+        System.out.println(list);
+
+        String s2 = BertWordPiecePreProcessor.reconstructFromTokens(list);
+        String exp = s.toLowerCase();
+        assertEquals(exp, s2);
+    }
 }