Fixes #8415, BERT iterator inference (#71)

* Convenience method for inference with BERT iterator

Signed-off-by: eraly <susan.eraly@gmail.com>

* Included preprocessing

Signed-off-by: eraly <susan.eraly@gmail.com>

* Copyright + example

Signed-off-by: eraly <susan.eraly@gmail.com>
master
Susan Eraly 2019-11-21 18:19:28 -08:00 committed by Alex Black
parent 7a90a31cfb
commit 823bd0ff88
2 changed files with 182 additions and 91 deletions

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -79,6 +80,17 @@ import java.util.Map;
* .build();
* }
* </pre>
* <br>
* <b>Example to use an instantiated iterator for inference:</b><br>
* <pre>
* {@code
* BertIterator b;
* List<String> forInference;
* Pair<INDArray[],INDArray[]> featuresAndMask = b.featurizeSentences(forInference);
* INDArray[] features = featuresAndMask.getFirst();
* INDArray[] featureMasks = featuresAndMask.getSecond();
* }
* </pre>
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br>
* <br>
* <u><b>{@link LengthHandling} configuration:</b></u><br>
@ -107,8 +119,11 @@ import java.util.Map;
public class BertIterator implements MultiDataSetIterator {
public enum Task {UNSUPERVISED, SEQ_CLASSIFICATION}
public enum LengthHandling {FIXED_LENGTH, ANY_LENGTH, CLIP_ONLY}
public enum FeatureArrays {INDICES_MASK, INDICES_MASK_SEGMENTID}
public enum UnsupervisedLabelFormat {RANK2_IDX, RANK3_NCL, RANK3_LNC}
protected Task task;
@ -116,7 +131,8 @@ public class BertIterator implements MultiDataSetIterator {
protected int maxTokens = -1;
protected int minibatchSize = 32;
protected boolean padMinibatches = false;
@Getter @Setter
@Getter
@Setter
protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null;
protected LengthHandling lengthHandling;
@ -167,9 +183,9 @@ public class BertIterator implements MultiDataSetIterator {
Preconditions.checkState(hasNext(), "No next element available");
List<Pair<String, String>> list = new ArrayList<>(num);
int count = 0;
int mbSize = 0;
if (sentenceProvider != null) {
while(sentenceProvider.hasNext() && count++ < num) {
while (sentenceProvider.hasNext() && mbSize++ < num) {
list.add(sentenceProvider.nextSentence());
}
} else {
@ -177,8 +193,87 @@ public class BertIterator implements MultiDataSetIterator {
throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
}
Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list);
List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
int outLength = outLTokenizedSentencesPair.getLeft();
Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength);
INDArray[] featureArray = featuresAndMaskArraysPair.getFirst();
INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond();
Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength);
INDArray[] labelArray = labelsAndMaskArraysPair.getFirst();
INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond();
org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(featureArray, labelArray, featureMaskArray, labelMaskArray);
if (preProcessor != null)
preProcessor.preProcess(mds);
return mds;
}
/**
* For use during inference. Will convert a given list of sentences to features and feature masks as appropriate.
*
* @param listOnlySentences
* @return Pair of INDArrays[], first element is feature arrays and the second is the masks array
*/
public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) {
List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel);
List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
int outLength = outLTokenizedSentencesPair.getLeft();
Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength);
if (preProcessor != null) {
MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null);
preProcessor.preProcess(dummyMDS);
return new Pair<INDArray[],INDArray[]>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
}
return convertMiniBatchFeatures(tokenizedSentences, outLength);
}
private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokenizedSentences, int outLength) {
int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
int[][] outIdxs = new int[mbPadded][outLength];
int[][] outMask = new int[mbPadded][outLength];
for (int i = 0; i < tokenizedSentences.size(); i++) {
Pair<List<String>, String> p = tokenizedSentences.get(i);
List<String> t = p.getFirst();
for (int j = 0; j < outLength && j < t.size(); j++) {
Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j));
int idx = vocabMap.get(t.get(j));
outIdxs[i][j] = idx;
outMask[i][j] = 1;
}
}
//Create actual arrays. Indices, mask, and optional segment ID
INDArray outIdxsArr = Nd4j.createFromArray(outIdxs);
INDArray outMaskArr = Nd4j.createFromArray(outMask);
INDArray outSegmentIdArr;
INDArray[] f;
INDArray[] fm;
if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) {
//For now: always segment index 0 (only single s sequence input supported)
outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength);
f = new INDArray[]{outIdxsArr, outSegmentIdArr};
fm = new INDArray[]{outMaskArr, null};
} else {
f = new INDArray[]{outIdxsArr};
fm = new INDArray[]{outMaskArr};
}
return new Pair<>(f, fm);
}
private Pair<Integer, List<Pair<List<String>, String>>> tokenizeMiniBatch(List<Pair<String, String>> list) {
//Get and tokenize the sentences for this minibatch
List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(num);
List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(list.size());
int longestSeq = -1;
for (Pair<String, String> p : list) {
List<String> tokens = tokenizeSentence(p.getFirst());
@ -201,41 +296,14 @@ public class BertIterator implements MultiDataSetIterator {
default:
throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
}
int mb = tokenizedSentences.size();
int mbPadded = padMinibatches ? minibatchSize : mb;
int[][] outIdxs = new int[mbPadded][outLength];
int[][] outMask = new int[mbPadded][outLength];
for( int i=0; i<tokenizedSentences.size(); i++ ){
Pair<List<String>,String> p = tokenizedSentences.get(i);
List<String> t = p.getFirst();
for( int j=0; j<outLength && j<t.size(); j++ ){
Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encontered: token \"%s\" is not in vocabulary", t.get(j));
int idx = vocabMap.get(t.get(j));
outIdxs[i][j] = idx;
outMask[i][j] = 1;
}
}
//Create actual arrays. Indices, mask, and optional segment ID
INDArray outIdxsArr = Nd4j.createFromArray(outIdxs);
INDArray outMaskArr = Nd4j.createFromArray(outMask);
INDArray outSegmentIdArr;
INDArray[] f;
INDArray[] fm;
if(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID){
//For now: always segment index 0 (only single s sequence input supported)
outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength);
f = new INDArray[]{outIdxsArr, outSegmentIdArr};
fm = new INDArray[]{outMaskArr, null};
} else {
f = new INDArray[]{outIdxsArr};
fm = new INDArray[]{outMaskArr};
return new Pair<>(outLength, tokenizedSentences);
}
private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
INDArray[] l = new INDArray[1];
INDArray[] lm;
int mbSize = tokenizedSentences.size();
int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
if (task == Task.SEQ_CLASSIFICATION) {
//Sequence classification task: output is 2d, one-hot, shape [minibatch, numClasses]
int numClasses;
@ -243,7 +311,7 @@ public class BertIterator implements MultiDataSetIterator {
if (sentenceProvider != null) {
numClasses = sentenceProvider.numLabelClasses();
List<String> labels = sentenceProvider.allLabels();
for(int i=0; i<mb; i++ ){
for (int i = 0; i < mbSize; i++) {
String lbl = tokenizedSentences.get(i).getRight();
classLabels[i] = labels.indexOf(lbl);
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
@ -252,14 +320,14 @@ public class BertIterator implements MultiDataSetIterator {
throw new RuntimeException();
}
l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses);
for( int i=0; i<mb; i++ ){
for (int i = 0; i < mbSize; i++) {
l[0].putScalar(i, classLabels[i], 1.0);
}
lm = null;
if(padMinibatches && mb != mbPadded){
if (padMinibatches && mbSize != mbPadded) {
INDArray a = Nd4j.zeros(DataType.FLOAT, mbPadded, 1);
lm = new INDArray[]{a};
a.get(NDArrayIndex.interval(0, mb), NDArrayIndex.all()).assign(1);
a.get(NDArrayIndex.interval(0, mbSize), NDArrayIndex.all()).assign(1);
}
} else if (task == Task.UNSUPERVISED) {
//Unsupervised, masked language model task
@ -286,7 +354,7 @@ public class BertIterator implements MultiDataSetIterator {
throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
}
for( int i=0; i<mb; i++ ){
for (int i = 0; i < mbSize; i++) {
List<String> tokens = tokenizedSentences.get(i).getFirst();
Pair<List<String>, boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList);
List<String> maskedTokens = p.getFirst();
@ -309,7 +377,8 @@ public class BertIterator implements MultiDataSetIterator {
//Also update previously created feature label indexes:
String newToken = maskedTokens.get(j);
int newTokenIdx = vocabMap.get(newToken);
outIdxsArr.putScalar(i,j,newTokenIdx);
//first element of features is outIdxsArr
featureArray[0].putScalar(i, j, newTokenIdx);
}
}
}
@ -319,12 +388,7 @@ public class BertIterator implements MultiDataSetIterator {
} else {
throw new IllegalStateException("Task not yet implemented: " + task);
}
org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(f, l, fm, lm);
if(preProcessor != null)
preProcessor.preProcess(mds);
return mds;
return new Pair<>(l, lm);
}
private List<String> tokenizeSentence(String sentence) {
@ -341,6 +405,16 @@ public class BertIterator implements MultiDataSetIterator {
return tokens;
}
private List<Pair<String, String>> addDummyLabel(List<String> listOnlySentences) {
List<Pair<String, String>> list = new ArrayList<>(listOnlySentences.size());
for (String s : listOnlySentences) {
list.add(new Pair<String, String>(s, null));
}
return list;
}
@Override
public boolean resetSupported() {
return true;
@ -399,6 +473,7 @@ public class BertIterator implements MultiDataSetIterator {
/**
* Specifies how the sequence length of the output data should be handled. See {@link BertIterator} for more details.
*
* @param lengthHandling Length handling
* @param maxLength Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH}
* @return
@ -412,6 +487,7 @@ public class BertIterator implements MultiDataSetIterator {
/**
* Minibatch size to use (number of examples to train on for each iteration)
* See also: {@link #padMinibatches}
*
* @param minibatchSize Minibatch size
*/
public Builder minibatchSize(int minibatchSize) {

View File

@ -26,7 +26,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.resources.Resources;
@ -34,10 +33,7 @@ import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.*;
import static org.junit.Assert.*;
@ -54,6 +50,9 @@ public class TestBertIterator extends BaseDL4JTest {
String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
List<String> forInference = new ArrayList<>();
forInference.add(toTokenize1);
forInference.add(toTokenize2);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
BertIterator b = BertIterator.builder()
@ -100,12 +99,15 @@ public class TestBertIterator extends BaseDL4JTest {
assertEquals(expF, mds.getFeatures(0));
assertEquals(expM, mds.getFeaturesMaskArray(0));
assertEquals(expF,b.featurizeSentences(forInference).getFirst()[0]);
assertEquals(expM,b.featurizeSentences(forInference).getSecond()[0]);
assertFalse(b.hasNext());
b.reset();
assertTrue(b.hasNext());
MultiDataSet mds2 = b.next();
forInference.set(0,toTokenize2);
//Same thing, but with segment ID also
b = BertIterator.builder()
.tokenizer(t)
@ -118,9 +120,11 @@ public class TestBertIterator extends BaseDL4JTest {
.build();
mds = b.next();
assertEquals(2, mds.getFeatures().length);
assertEquals(2,b.featurizeSentences(forInference).getFirst().length);
//Segment ID should be all 0s for single segment task
INDArray segmentId = expM.like();
assertEquals(segmentId, mds.getFeatures(1));
assertEquals(segmentId,b.featurizeSentences(forInference).getFirst()[1]);
}
@Test(timeout = 20000L)
@ -157,6 +161,9 @@ public class TestBertIterator extends BaseDL4JTest {
public void testLengthHandling() throws Exception {
String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
List<String> forInference = new ArrayList<>();
forInference.add(toTokenize1);
forInference.add(toTokenize2);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
@ -205,6 +212,8 @@ public class TestBertIterator extends BaseDL4JTest {
assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeatures(0));
assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeaturesMaskArray(0));
assertEquals(mds.getFeatures(0),b.featurizeSentences(forInference).getFirst()[0]);
assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]);
//Clip only: clip to maximum, but don't pad if less
b = BertIterator.builder()
@ -227,6 +236,9 @@ public class TestBertIterator extends BaseDL4JTest {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
List<String> forInference = new ArrayList<>();
forInference.add(toTokenize1);
forInference.add(toTokenize2);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
@ -288,6 +300,9 @@ public class TestBertIterator extends BaseDL4JTest {
assertEquals(expM, mds.getFeaturesMaskArray(0));
assertEquals(expL, mds.getLabels(0));
assertEquals(expLM, mds.getLabelsMaskArray(0));
assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
}
private static class TestSentenceProvider implements LabeledSentenceProvider {