cleaned up bert iterator tests (#110)

Signed-off-by: eraly <susan.eraly@gmail.com>
Authored by Susan Eraly, 2019-12-04 18:24:37 -08:00; committed by Alex Black
parent 9592072cef
commit 63ed202057
1 changed file with 338 additions and 343 deletions

@@ -17,11 +17,13 @@
 package org.deeplearning4j.iterator;
+import lombok.Getter;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
+import org.deeplearning4j.iterator.provider.CollectionLabeledPairSentenceProvider;
+import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider;
 import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
 import org.junit.Test;
-import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
@@ -42,8 +44,12 @@ import static org.junit.Assert.*;
 public class TestBertIterator extends BaseDL4JTest {

-    private File pathToVocab = Resources.asFile("other/vocab.txt");
+    private static File pathToVocab = Resources.asFile("other/vocab.txt");
     private static Charset c = StandardCharsets.UTF_8;
+    private static String shortSentence = "I saw a girl with a telescope.";
+    private static String longSentence = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+    private static String sentenceA = "Goodnight noises everywhere";
+    private static String sentenceB = "Goodnight moon";

     public TestBertIterator() throws IOException {
     }
@@ -51,20 +57,15 @@ public class TestBertIterator extends BaseDL4JTest {
     @Test(timeout = 20000L)
     public void testBertSequenceClassification() throws Exception {
-        String toTokenize1 = "I saw a girl with a telescope.";
-        String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
-        List<String> forInference = new ArrayList<>();
-        forInference.add(toTokenize1);
-        forInference.add(toTokenize2);
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+        int minibatchSize = 2;
+        TestSentenceHelper testHelper = new TestSentenceHelper();

         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
-                .minibatchSize(2)
-                .sentenceProvider(new TestSentenceProvider())
+                .minibatchSize(minibatchSize)
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
@@ -73,82 +74,77 @@ public class TestBertIterator extends BaseDL4JTest {
         System.out.println(mds.getFeatures(0));
         System.out.println(mds.getFeaturesMaskArray(0));

-        INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens = t.create(toTokenize1).getTokens();
-        Map<String, Integer> m = t.getVocab();
-        for (int i = 0; i < tokens.size(); i++) {
-            int idx = m.get(tokens.get(i));
-            expEx0.putScalar(0, i, idx);
-            expM0.putScalar(0, i, 1);
-        }
-
-        INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens2 = t.create(toTokenize2).getTokens();
-        for (int i = 0; i < tokens2.size(); i++) {
-            String token = tokens2.get(i);
-            if (!m.containsKey(token)) {
-                throw new IllegalStateException("Unknown token: \"" + token + "\"");
-            }
-            int idx = m.get(token);
-            expEx1.putScalar(0, i, idx);
-            expM1.putScalar(0, i, 1);
-        }
-        INDArray expF = Nd4j.vstack(expEx0, expEx1);
-        INDArray expM = Nd4j.vstack(expM0, expM1);
+        INDArray expF = Nd4j.create(DataType.INT, 1, 16);
+        INDArray expM = Nd4j.create(DataType.INT, 1, 16);
+        Map<String, Integer> m = testHelper.getTokenizer().getVocab();
+        for (int i = 0; i < minibatchSize; i++) {
+            INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
+            INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
+            List<String> tokens = testHelper.getTokenizedSentences().get(i);
+            System.out.println(tokens);
+            for (int j = 0; j < tokens.size(); j++) {
+                String token = tokens.get(j);
+                if (!m.containsKey(token)) {
+                    throw new IllegalStateException("Unknown token: \"" + token + "\"");
+                }
+                int idx = m.get(token);
+                expFTemp.putScalar(0, j, idx);
+                expMTemp.putScalar(0, j, 1);
+            }
+            if (i == 0) {
+                expF = expFTemp.dup();
+                expM = expMTemp.dup();
+            } else {
+                expF = Nd4j.vstack(expF, expFTemp);
+                expM = Nd4j.vstack(expM, expMTemp);
+            }
+        }

         assertEquals(expF, mds.getFeatures(0));
         assertEquals(expM, mds.getFeaturesMaskArray(0));
-        assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
-        assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
+        assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
+        assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);

-        b.next(); //pop the third element
         assertFalse(b.hasNext());
         b.reset();
         assertTrue(b.hasNext());
-        forInference.set(0, toTokenize2);

         //Same thing, but with segment ID also
         b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
-                .minibatchSize(2)
-                .sentenceProvider(new TestSentenceProvider())
+                .minibatchSize(minibatchSize)
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();

         mds = b.next();
         assertEquals(2, mds.getFeatures().length);
-        //assertEquals(2, mds.getFeaturesMaskArrays().length); second element is null...
-        assertEquals(2, b.featurizeSentences(forInference).getFirst().length);

         //Segment ID should be all 0s for single segment task
         INDArray segmentId = expM.like();
         assertEquals(segmentId, mds.getFeatures(1));
-        assertEquals(segmentId, b.featurizeSentences(forInference).getFirst()[1]);
+        assertEquals(segmentId, b.featurizeSentences(testHelper.getSentences()).getFirst()[1]);
     }

     @Test(timeout = 20000L)
     public void testBertUnsupervised() throws Exception {
+        int minibatchSize = 2;
+        TestSentenceHelper testHelper = new TestSentenceHelper();
         //Task 1: Unsupervised
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
-                .minibatchSize(2)
-                .sentenceProvider(new TestSentenceProvider())
+                .minibatchSize(minibatchSize)
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.UNSUPERVISED)
                 .masker(new BertMaskedLMMasker(new Random(12345), 0.2, 0.5, 0.5))
                 .unsupervisedLabelFormat(BertIterator.UnsupervisedLabelFormat.RANK2_IDX)
                 .maskToken("[MASK]")
                 .build();

-        System.out.println("Mask token index: " + t.getVocab().get("[MASK]"));
+        System.out.println("Mask token index: " + testHelper.getTokenizer().getVocab().get("[MASK]"));

         MultiDataSet mds = b.next();
         System.out.println(mds.getFeatures(0));
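
The refactor above replaces two copy-pasted per-sentence blocks with a single loop over the helper's tokenized sentences. A minimal standalone sketch of that consolidated pattern, assuming only an ND4J dependency; the class and method names here are illustrative, not part of the commit:

import java.util.List;
import java.util.Map;

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ExpectedFeatureSketch {

    // Builds the expected [minibatch, seqLen] token-index and mask arrays the
    // way the refactored tests do: one row per sentence, vstack'd together.
    static INDArray[] expectedFeaturesAndMask(List<List<String>> tokenizedSentences,
                                              Map<String, Integer> vocab, int seqLen) {
        INDArray expF = null;
        INDArray expM = null;
        for (List<String> tokens : tokenizedSentences) {
            INDArray f = Nd4j.create(DataType.INT, 1, seqLen);
            INDArray m = Nd4j.create(DataType.INT, 1, seqLen);
            for (int j = 0; j < tokens.size(); j++) {
                String token = tokens.get(j);
                if (!vocab.containsKey(token)) {
                    throw new IllegalStateException("Unknown token: \"" + token + "\"");
                }
                f.putScalar(0, j, vocab.get(token)); // token index
                m.putScalar(0, j, 1);                // 1 = real token, 0 = padding
            }
            expF = (expF == null) ? f : Nd4j.vstack(expF, f);
            expM = (expM == null) ? m : Nd4j.vstack(expM, m);
        }
        return new INDArray[]{expF, expM};
    }
}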
@@ -156,7 +152,6 @@ public class TestBertIterator extends BaseDL4JTest {
         System.out.println(mds.getLabels(0));
         System.out.println(mds.getLabelsMaskArray(0));

-        b.next(); //pop the third element
         assertFalse(b.hasNext());
         b.reset();
         assertTrue(b.hasNext());
@@ -164,40 +159,34 @@ public class TestBertIterator extends BaseDL4JTest {
     @Test(timeout = 20000L)
     public void testLengthHandling() throws Exception {
-        String toTokenize1 = "I saw a girl with a telescope.";
-        String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
-        List<String> forInference = new ArrayList<>();
-        forInference.add(toTokenize1);
-        forInference.add(toTokenize2);
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
-        INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens = t.create(toTokenize1).getTokens();
-        System.out.println(tokens);
-        Map<String, Integer> m = t.getVocab();
-        for (int i = 0; i < tokens.size(); i++) {
-            int idx = m.get(tokens.get(i));
-            expEx0.putScalar(0, i, idx);
-            expM0.putScalar(0, i, 1);
-        }
-
-        INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens2 = t.create(toTokenize2).getTokens();
-        System.out.println(tokens2);
-        for (int i = 0; i < tokens2.size(); i++) {
-            String token = tokens2.get(i);
-            if (!m.containsKey(token)) {
-                throw new IllegalStateException("Unknown token: \"" + token + "\"");
-            }
-            int idx = m.get(token);
-            expEx1.putScalar(0, i, idx);
-            expM1.putScalar(0, i, 1);
-        }
-        INDArray expF = Nd4j.vstack(expEx0, expEx1);
-        INDArray expM = Nd4j.vstack(expM0, expM1);
+        int minibatchSize = 2;
+        TestSentenceHelper testHelper = new TestSentenceHelper();
+        INDArray expF = Nd4j.create(DataType.INT, 1, 16);
+        INDArray expM = Nd4j.create(DataType.INT, 1, 16);
+        Map<String, Integer> m = testHelper.getTokenizer().getVocab();
+        for (int i = 0; i < minibatchSize; i++) {
+            List<String> tokens = testHelper.getTokenizedSentences().get(i);
+            INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
+            INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
+            System.out.println(tokens);
+            for (int j = 0; j < tokens.size(); j++) {
+                String token = tokens.get(j);
+                if (!m.containsKey(token)) {
+                    throw new IllegalStateException("Unknown token: \"" + token + "\"");
+                }
+                int idx = m.get(token);
+                expFTemp.putScalar(0, j, idx);
+                expMTemp.putScalar(0, j, 1);
+            }
+            if (i == 0) {
+                expF = expFTemp.dup();
+                expM = expMTemp.dup();
+            } else {
+                expF = Nd4j.vstack(expF, expFTemp);
+                expM = Nd4j.vstack(expM, expMTemp);
+            }
+        }

         //--------------------------------------------------------------
         //Fixed length: clip or pad - already tested in other tests
@@ -205,12 +194,12 @@ public class TestBertIterator extends BaseDL4JTest {
         //Any length: as long as we need to fit longest sequence

         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.ANY_LENGTH, -1)
-                .minibatchSize(2)
-                .sentenceProvider(new TestSentenceProvider())
+                .minibatchSize(minibatchSize)
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
         MultiDataSet mds = b.next();
@@ -219,20 +208,19 @@ public class TestBertIterator extends BaseDL4JTest {
         assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
         assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeatures(0));
         assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeaturesMaskArray(0));
-        assertEquals(mds.getFeatures(0), b.featurizeSentences(forInference).getFirst()[0]);
-        assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]);
+        assertEquals(mds.getFeatures(0), b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
+        assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);

         //Clip only: clip to maximum, but don't pad if less
         b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.CLIP_ONLY, 20)
-                .minibatchSize(2)
-                .sentenceProvider(new TestSentenceProvider())
+                .minibatchSize(minibatchSize)
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
-        mds = b.next();
         expShape = new long[]{2, 14};
         assertArrayEquals(expShape, mds.getFeatures(0).shape());
         assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
@@ -241,54 +229,38 @@ public class TestBertIterator extends BaseDL4JTest {
     @Test(timeout = 20000L)
     public void testMinibatchPadding() throws Exception {
         Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
-        String toTokenize1 = "I saw a girl with a telescope.";
-        String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
-        String toTokenize3 = "Goodnight noises everywhere";
-        List<String> forInference = new ArrayList<>();
-        forInference.add(toTokenize1);
-        forInference.add(toTokenize2);
-        forInference.add(toTokenize3);
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
-
-        INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens = t.create(toTokenize1).getTokens();
-        Map<String, Integer> m = t.getVocab();
-        for (int i = 0; i < tokens.size(); i++) {
-            int idx = m.get(tokens.get(i));
-            expEx0.putScalar(0, i, idx);
-            expM0.putScalar(0, i, 1);
-        }
-
-        INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens2 = t.create(toTokenize2).getTokens();
-        for (int i = 0; i < tokens2.size(); i++) {
-            String token = tokens2.get(i);
-            if (!m.containsKey(token)) {
-                throw new IllegalStateException("Unknown token: \"" + token + "\"");
-            }
-            int idx = m.get(token);
-            expEx1.putScalar(0, i, idx);
-            expM1.putScalar(0, i, 1);
-        }
-
-        INDArray expEx3 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM3 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens3 = t.create(toTokenize3).getTokens();
-        for (int i = 0; i < tokens3.size(); i++) {
-            String token = tokens3.get(i);
-            if (!m.containsKey(token)) {
-                throw new IllegalStateException("Unknown token: \"" + token + "\"");
-            }
-            int idx = m.get(token);
-            expEx3.putScalar(0, i, idx);
-            expM3.putScalar(0, i, 1);
-        }
+        int minibatchSize = 3;
+        TestSentenceHelper testHelper = new TestSentenceHelper(minibatchSize);

         INDArray zeros = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expF = Nd4j.vstack(expEx0, expEx1, expEx3, zeros);
-        INDArray expM = Nd4j.vstack(expM0, expM1, expM3, zeros);
-        INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {1, 0}, {0, 0}});
+        INDArray expF = Nd4j.create(DataType.INT, 1, 16);
+        INDArray expM = Nd4j.create(DataType.INT, 1, 16);
+        Map<String, Integer> m = testHelper.getTokenizer().getVocab();
+        for (int i = 0; i < minibatchSize; i++) {
+            List<String> tokens = testHelper.getTokenizedSentences().get(i);
+            INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
+            INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
+            System.out.println(tokens);
+            for (int j = 0; j < tokens.size(); j++) {
+                String token = tokens.get(j);
+                if (!m.containsKey(token)) {
+                    throw new IllegalStateException("Unknown token: \"" + token + "\"");
+                }
+                int idx = m.get(token);
+                expFTemp.putScalar(0, j, idx);
+                expMTemp.putScalar(0, j, 1);
+            }
+            if (i == 0) {
+                expF = expFTemp.dup();
+                expM = expMTemp.dup();
+            } else {
+                expF = Nd4j.vstack(expF.dup(), expFTemp);
+                expM = Nd4j.vstack(expM.dup(), expMTemp);
+            }
+        }
+        expF = Nd4j.vstack(expF, zeros);
+        expM = Nd4j.vstack(expM, zeros);
+        INDArray expL = Nd4j.createFromArray(new float[][]{{0, 1}, {1, 0}, {0, 1}, {0, 0}});
         INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1);
         expLM.putScalar(0, 0, 1);
         expLM.putScalar(1, 0, 1);
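
testMinibatchPadding asks the iterator (padMinibatches(true), minibatch 4 over 3 sentences) to zero-fill the fourth row of every feature array and to flag only the real rows in the label mask. A short sketch of just that expectation, again assuming an ND4J classpath; the names and sizes are illustrative:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class PaddingExpectationSketch {
    public static void main(String[] args) {
        int realExamples = 3, paddedTo = 4, seqLen = 16;

        // Features/masks for the padded slot are all zeros: the test vstacks a
        // zero row onto the expected arrays, exactly as below.
        INDArray realRows = Nd4j.create(DataType.INT, realExamples, seqLen); // stand-in for expF/expM
        INDArray padded = Nd4j.vstack(realRows, Nd4j.create(DataType.INT, 1, seqLen));
        System.out.println(java.util.Arrays.toString(padded.shape())); // [4, 16]

        // Label mask: 1 per real example, 0 for the padded slot, so the padded
        // row contributes nothing to the loss.
        INDArray labelMask = Nd4j.create(DataType.FLOAT, paddedTo, 1);
        for (int i = 0; i < realExamples; i++) {
            labelMask.putScalar(i, 0, 1);
        }
        System.out.println(labelMask); // [[1.0], [1.0], [1.0], [0.0]]
    }
}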
@@ -297,13 +269,13 @@ public class TestBertIterator extends BaseDL4JTest {
         //--------------------------------------------------------------

         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
-                .minibatchSize(4)
+                .minibatchSize(minibatchSize + 1)
                 .padMinibatches(true)
-                .sentenceProvider(new TestSentenceProvider())
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
@@ -323,170 +295,175 @@ public class TestBertIterator extends BaseDL4JTest {
         assertEquals(expL, mds.getLabels(0));
         assertEquals(expLM, mds.getLabelsMaskArray(0));
-        assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
-        assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
+        assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
+        assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);
     }

+    /*
+        Checks that a mds from a pair sentence is equal to hstack'd mds from the left side and right side of the pair
+        Checks different lengths for max length to check popping and padding
+     */
     @Test
     public void testSentencePairsSingle() throws IOException {
-        String shortSent = "I saw a girl with a telescope.";
-        String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
         boolean prependAppend;
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
-        int shortL = t.create(shortSent).countTokens();
-        int longL = t.create(longSent).countTokens();
+        int numOfSentences;
+
+        TestSentenceHelper testHelper = new TestSentenceHelper();
+        int shortL = testHelper.getShortestL();
+        int longL = testHelper.getLongestL();

         Triple<MultiDataSet, MultiDataSet, MultiDataSet> multiDataSetTriple;
-        MultiDataSet shortLongPair, shortSentence, longSentence;
+        MultiDataSet fromPair, leftSide, rightSide;

         // check for pair max length exactly equal to sum of lengths - pop neither no padding
         // should be the same as hstack with segment ids 1 for second sentence
         prependAppend = true;
-        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend);
-        shortLongPair = multiDataSetTriple.getFirst();
-        shortSentence = multiDataSetTriple.getSecond();
-        longSentence = multiDataSetTriple.getThird();
-        assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
-        longSentence.getFeatures(1).addi(1);
-        assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
-        assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
+        numOfSentences = 1;
+        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend, numOfSentences);
+        fromPair = multiDataSetTriple.getFirst();
+        leftSide = multiDataSetTriple.getSecond();
+        rightSide = multiDataSetTriple.getThird();
+        assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
+        rightSide.getFeatures(1).addi(1); //add 1 for right side segment ids
+        assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
+        assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));

         //check for pair max length greater than sum of lengths - pop neither with padding
         // features should be the same as hstack of shorter and longer padded with prepend/append
         // segment id should 1 only in the longer for part of the length of the sentence
         prependAppend = true;
-        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend);
-        shortLongPair = multiDataSetTriple.getFirst();
-        shortSentence = multiDataSetTriple.getSecond();
-        longSentence = multiDataSetTriple.getThird();
-        assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
-        longSentence.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part
-        assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
-        assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
+        numOfSentences = 1;
+        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend, numOfSentences);
+        fromPair = multiDataSetTriple.getFirst();
+        leftSide = multiDataSetTriple.getSecond();
+        rightSide = multiDataSetTriple.getThird();
+        assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
+        rightSide.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part
+        assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
+        assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));

         //check for pair max length less than shorter sentence - pop both
         //should be the same as hstack with segment ids 1 for second sentence if no prepend/append
-        int maxL = shortL - 2;
+        int maxL = 5;//checking odd
+        numOfSentences = 3;
         prependAppend = false;
-        multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend);
-        shortLongPair = multiDataSetTriple.getFirst();
-        shortSentence = multiDataSetTriple.getSecond();
-        longSentence = multiDataSetTriple.getThird();
-        assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
-        longSentence.getFeatures(1).addi(1);
-        assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
-        assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
+        multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend, numOfSentences);
+        fromPair = multiDataSetTriple.getFirst();
+        leftSide = multiDataSetTriple.getSecond();
+        rightSide = multiDataSetTriple.getThird();
+        assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
+        rightSide.getFeatures(1).addi(1);
+        assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
+        assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));
     }

+    /*
+        Same idea as previous test - construct mds from bert iterator with sep sentences and check against one with pairs
+        Checks various max lengths
+        Has sentences of varying lengths
+     */
     @Test
     public void testSentencePairsUnequalLengths() throws IOException {
-        //check for pop only longer (i.e between longer and longer + shorter), first row pop from second sentence, next row pop from first sentence, nothing to pop in the third row
-        //should be identical to hstack if there is no append, prepend
-        //batch size is 2
-        int mbS = 4;
-        String shortSent = "I saw a girl with a telescope.";
-        String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
-        String sent1 = "Goodnight noises everywhere"; //shorter than shortSent - no popping
-        String sent2 = "Goodnight moon"; //shorter than shortSent - no popping
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
-        int shortL = t.create(shortSent).countTokens();
-        int longL = t.create(longSent).countTokens();
-        int sent1L = t.create(sent1).countTokens();
-        int sent2L = t.create(sent2).countTokens();
-        //won't check 2*shortL + 1 because this will always pop on the left
-        for (int maxL = longL + shortL - 1; maxL > 2 * shortL; maxL--) {
+        int minibatchSize = 4;
+        int numOfSentencesinIter = 3;
+
+        TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(numOfSentencesinIter);
+        int shortL = testPairHelper.getShortL();
+        int longL = testPairHelper.getLongL();
+        int sent1L = testPairHelper.getSentenceALen();
+        int sent2L = testPairHelper.getSentenceBLen();
+
+        System.out.println("Sentence Pairs, Left");
+        System.out.println(testPairHelper.getSentencesLeft());
+        System.out.println("Sentence Pairs, Right");
+        System.out.println(testPairHelper.getSentencesRight());
+
+        //anything outside this range more will need to check padding,truncation
+        for (int maxL = longL + shortL; maxL > 2 * shortL + 1; maxL--) {
+            System.out.println("Running for max length = " + maxL);
             MultiDataSet leftMDS = BertIterator.builder()
-                    .tokenizer(t)
-                    .minibatchSize(mbS)
+                    .tokenizer(testPairHelper.getTokenizer())
+                    .minibatchSize(minibatchSize)
                     .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                    .vocabMap(t.getVocab())
+                    .vocabMap(testPairHelper.getTokenizer().getVocab())
                     .task(BertIterator.Task.SEQ_CLASSIFICATION)
-                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
-                    .sentenceProvider(new TestSentenceProvider())
+                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either
+                    .sentenceProvider(new TestSentenceHelper(numOfSentencesinIter).getSentenceProvider())
                     .padMinibatches(true)
                     .build().next();

             MultiDataSet rightMDS = BertIterator.builder()
-                    .tokenizer(t)
-                    .minibatchSize(mbS)
+                    .tokenizer(testPairHelper.getTokenizer())
+                    .minibatchSize(minibatchSize)
                     .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                    .vocabMap(t.getVocab())
+                    .vocabMap(testPairHelper.getTokenizer().getVocab())
                     .task(BertIterator.Task.SEQ_CLASSIFICATION)
-                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
-                    .sentenceProvider(new TestSentenceProvider(true))
+                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either
+                    .sentenceProvider(new TestSentenceHelper(true, numOfSentencesinIter).getSentenceProvider())
                     .padMinibatches(true)
                     .build().next();

             MultiDataSet pairMDS = BertIterator.builder()
-                    .tokenizer(t)
-                    .minibatchSize(mbS)
+                    .tokenizer(testPairHelper.getTokenizer())
+                    .minibatchSize(minibatchSize)
                     .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                    .vocabMap(t.getVocab())
+                    .vocabMap(testPairHelper.getTokenizer().getVocab())
                     .task(BertIterator.Task.SEQ_CLASSIFICATION)
-                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) //random big num guaranteed to be longer than either
-                    .sentencePairProvider(new TestSentencePairProvider())
+                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL)
+                    .sentencePairProvider(testPairHelper.getPairSentenceProvider())
                     .padMinibatches(true)
                     .build().next();

-            //Left sentences here are {{shortSent},
-            //                         {longSent},
-            //                         {Sent1}}
-            //Right sentences here are {{longSent},
-            //                          {shortSent},
-            //                          {Sent2}}
-            //The sentence pairs here are {{shortSent,longSent},
-            //                             {longSent,shortSent}
-            //                             {Sent1, Sent2}}
-
             //CHECK FEATURES
-            INDArray combinedFeat = Nd4j.create(DataType.INT,mbS,maxL);
+            INDArray combinedFeat = Nd4j.create(DataType.INT, minibatchSize, maxL);
             //left side
             INDArray leftFeatures = leftMDS.getFeatures(0);
             INDArray topLSentFeat = leftFeatures.getRow(0).get(NDArrayIndex.interval(0, shortL));
             INDArray midLSentFeat = leftFeatures.getRow(1).get(NDArrayIndex.interval(0, maxL - shortL));
-            INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0,sent1L));
+            INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0, sent1L));
             //right side
             INDArray rightFeatures = rightMDS.getFeatures(0);
             INDArray topRSentFeat = rightFeatures.getRow(0).get(NDArrayIndex.interval(0, maxL - shortL));
             INDArray midRSentFeat = rightFeatures.getRow(1).get(NDArrayIndex.interval(0, shortL));
-            INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0,sent2L));
+            INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0, sent2L));
             //expected pair
-            combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat,topRSentFeat));
-            combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat,midRSentFeat));
-            combinedFeat.getRow(2).get(NDArrayIndex.interval(0,sent1L+sent2L)).addi(Nd4j.hstack(bottomLSentFeat,bottomRSentFeat));
+            combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat, topRSentFeat));
+            combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat, midRSentFeat));
+            combinedFeat.getRow(2).get(NDArrayIndex.interval(0, sent1L + sent2L)).addi(Nd4j.hstack(bottomLSentFeat, bottomRSentFeat));

             assertEquals(maxL, pairMDS.getFeatures(0).shape()[1]);
             assertArrayEquals(combinedFeat.shape(), pairMDS.getFeatures(0).shape());
             assertEquals(combinedFeat, pairMDS.getFeatures(0));

             //CHECK SEGMENT ID
-            INDArray combinedFetSeg = Nd4j.create(DataType.INT, mbS, maxL);
+            INDArray combinedFetSeg = Nd4j.create(DataType.INT, minibatchSize, maxL);
             combinedFetSeg.get(NDArrayIndex.point(0), NDArrayIndex.interval(shortL, maxL)).addi(1);
             combinedFetSeg.get(NDArrayIndex.point(1), NDArrayIndex.interval(maxL - shortL, maxL)).addi(1);
-            combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L+sent2L)).addi(1);
+            combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L + sent2L)).addi(1);
             assertArrayEquals(combinedFetSeg.shape(), pairMDS.getFeatures(1).shape());
             assertEquals(maxL, combinedFetSeg.shape()[1]);
             assertEquals(combinedFetSeg, pairMDS.getFeatures(1));
+            testPairHelper.getPairSentenceProvider().reset();
         }
     }

     @Test
     public void testSentencePairFeaturizer() throws IOException {
-        String shortSent = "I saw a girl with a telescope.";
-        String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
-        List<Pair<String, String>> listSentencePair = new ArrayList<>();
-        listSentencePair.add(new Pair<>(shortSent, longSent));
-        listSentencePair.add(new Pair<>(longSent, shortSent));
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+        int minibatchSize = 2;
+        TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(minibatchSize);
         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
-                .minibatchSize(2)
+                .tokenizer(testPairHelper.getTokenizer())
+                .minibatchSize(minibatchSize)
                 .padMinibatches(true)
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                .vocabMap(t.getVocab())
+                .vocabMap(testPairHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)
-                .sentencePairProvider(new TestSentencePairProvider())
+                .sentencePairProvider(testPairHelper.getPairSentenceProvider())
                 .prependToken("[CLS]")
                 .appendToken("[SEP]")
                 .build();
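
Both pair tests lean on the same invariant: a sentence-pair batch equals the hstack of the corresponding left-only and right-only batches, with segment IDs raised to 1 over the second sentence's token span and left at 0 over padding. A compact sketch of building one expected segment-ID row, assuming ND4J; the row and length values are illustrative:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class SegmentIdSketch {
    public static void main(String[] args) {
        int minibatch = 3, maxL = 12;
        int firstLen = 5, secondLen = 4; // token counts of the pair in row 0

        // Segment ids: 0 over the first sentence, 1 over the second, and 0
        // again over any padding past firstLen + secondLen.
        INDArray segmentIds = Nd4j.create(DataType.INT, minibatch, maxL);
        segmentIds.get(NDArrayIndex.point(0), NDArrayIndex.interval(firstLen, firstLen + secondLen)).addi(1);
        System.out.println(segmentIds.getRow(0)); // [0,0,0,0,0,1,1,1,1,0,0,0]
    }
}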
@@ -494,23 +471,19 @@ public class TestBertIterator extends BaseDL4JTest {
         INDArray[] featuresArr = mds.getFeatures();
         INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays();

-        Pair<INDArray[], INDArray[]> p = b.featurizeSentencePairs(listSentencePair);
+        Pair<INDArray[], INDArray[]> p = b.featurizeSentencePairs(testPairHelper.getSentencePairs());
         assertEquals(p.getFirst().length, 2);
         assertEquals(featuresArr[0], p.getFirst()[0]);
         assertEquals(featuresArr[1], p.getFirst()[1]);
-        //assertEquals(p.getSecond().length, 2);
         assertEquals(featuresMaskArr[0], p.getSecond()[0]);
-        //assertEquals(featuresMaskArr[1], p.getSecond()[1]);
     }

     /**
-     * Returns three multidatasets from bert iterator based on given max lengths and whether to prepend/append
+     * Returns three multidatasets (one from pair of sentences and the other two from single sentence lists) from bert iterator
+     * with given max lengths and whether to prepend/append
      * Idea is the sentence pair dataset can be constructed from the single sentence datasets
-     * First one is constructed from a sentence pair "I saw a girl with a telescope." & "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"
-     * Second one is constructed from the left of the sentence pair i.e "I saw a girl with a telescope."
-     * Third one is constructed from the right of the sentence pair i.e "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"
      */
-    private Triple<MultiDataSet, MultiDataSet, MultiDataSet> generateMultiDataSets(Triple<Integer, Integer, Integer> maxLengths, boolean prependAppend) throws IOException {
+    private Triple<MultiDataSet, MultiDataSet, MultiDataSet> generateMultiDataSets(Triple<Integer, Integer, Integer> maxLengths, boolean prependAppend, int numSentences) throws IOException {
         BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
         int maxforPair = maxLengths.getFirst();
         int maxPartOne = maxLengths.getSecond();
@@ -518,133 +491,155 @@ public class TestBertIterator extends BaseDL4JTest {
         BertIterator.Builder commonBuilder;
         commonBuilder = BertIterator.builder()
                 .tokenizer(t)
-                .minibatchSize(1)
+                .minibatchSize(4)
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
                 .vocabMap(t.getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION);
-        BertIterator shortLongPairFirstIter = commonBuilder
+        BertIterator pairIter = commonBuilder
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxforPair + 3 : maxforPair)
-                .sentencePairProvider(new TestSentencePairProvider())
+                .sentencePairProvider(new TestSentencePairsHelper(numSentences).getPairSentenceProvider())
                 .prependToken(prependAppend ? "[CLS]" : null)
                 .appendToken(prependAppend ? "[SEP]" : null)
                 .build();
-        BertIterator shortFirstIter = commonBuilder
+        BertIterator leftIter = commonBuilder
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartOne + 2 : maxPartOne)
-                .sentenceProvider(new TestSentenceProvider())
+                .sentenceProvider(new TestSentenceHelper(numSentences).getSentenceProvider())
                 .prependToken(prependAppend ? "[CLS]" : null)
                 .appendToken(prependAppend ? "[SEP]" : null)
                 .build();
-        BertIterator longFirstIter = commonBuilder
+        BertIterator rightIter = commonBuilder
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartTwo + 1 : maxPartTwo)
-                .sentenceProvider(new TestSentenceProvider(true))
+                .sentenceProvider(new TestSentenceHelper(true, numSentences).getSentenceProvider())
                 .prependToken(null)
                 .appendToken(prependAppend ? "[SEP]" : null)
                 .build();
-        return new Triple<>(shortLongPairFirstIter.next(), shortFirstIter.next(), longFirstIter.next());
+        return new Triple<>(pairIter.next(), leftIter.next(), rightIter.next());
     }

-    private static class TestSentenceProvider implements LabeledSentenceProvider {
-
-        private int pos = 0;
-        private boolean invert;
-
-        private TestSentenceProvider() {
-            this.invert = false;
-        }
-
-        private TestSentenceProvider(boolean invert) {
-            this.invert = invert;
-        }
-
-        @Override
-        public boolean hasNext() {
-            return pos < totalNumSentences();
-        }
-
-        @Override
-        public Pair<String, String> nextSentence() {
-            Preconditions.checkState(hasNext());
-            if (pos == 0) {
-                pos++;
-                if (!invert) return new Pair<>("I saw a girl with a telescope.", "positive");
-                return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
-            } else {
-                if (pos == 1) {
-                    pos++;
-                    if (!invert) return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
-                    return new Pair<>("I saw a girl with a telescope.", "positive");
-                }
-                pos++;
-                if (!invert)
-                    return new Pair<>("Goodnight noises everywhere", "positive");
-                return new Pair<>("Goodnight moon", "positive");
-            }
-        }
-
-        @Override
-        public void reset() {
-            pos = 0;
-        }
-
-        @Override
-        public int totalNumSentences() {
-            return 3;
-        }
-
-        @Override
-        public List<String> allLabels() {
-            return Arrays.asList("positive", "negative");
-        }
-
-        @Override
-        public int numLabelClasses() {
-            return 2;
-        }
-    }
-
-    private static class TestSentencePairProvider implements LabeledPairSentenceProvider {
-
-        private int pos = 0;
-
-        @Override
-        public boolean hasNext() {
-            return pos < totalNumSentences();
-        }
-
-        @Override
-        public Triple<String, String, String> nextSentencePair() {
-            Preconditions.checkState(hasNext());
-            if (pos == 0) {
-                pos++;
-                return new Triple<>("I saw a girl with a telescope.", "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "positive");
-            } else {
-                if (pos == 1) {
-                    pos++;
-                    return new Triple<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "I saw a girl with a telescope.", "negative");
-                }
-                pos++;
-                return new Triple<>("Goodnight noises everywhere", "Goodnight moon", "positive");
-            }
-        }
-
-        @Override
-        public void reset() {
-            pos = 0;
-        }
-
-        @Override
-        public int totalNumSentences() {
-            return 3;
-        }
-
-        @Override
-        public List<String> allLabels() {
-            return Arrays.asList("positive", "negative");
-        }
-
-        @Override
-        public int numLabelClasses() {
-            return 2;
-        }
-    }
+    @Getter
+    private static class TestSentencePairsHelper {
+
+        private List<String> sentencesLeft;
+        private List<String> sentencesRight;
+        private List<Pair<String, String>> sentencePairs;
+        private List<List<String>> tokenizedSentencesLeft;
+        private List<List<String>> tokenizedSentencesRight;
+        private List<String> labels;
+        private int shortL;
+        private int longL;
+        private int sentenceALen;
+        private int sentenceBLen;
+        private BertWordPieceTokenizerFactory tokenizer;
+        private CollectionLabeledPairSentenceProvider pairSentenceProvider;
+
+        private TestSentencePairsHelper() throws IOException {
+            this(3);
+        }
+
+        private TestSentencePairsHelper(int minibatchSize) throws IOException {
+            sentencesLeft = new ArrayList<>();
+            sentencesRight = new ArrayList<>();
+            sentencePairs = new ArrayList<>();
+            labels = new ArrayList<>();
+            tokenizedSentencesLeft = new ArrayList<>();
+            tokenizedSentencesRight = new ArrayList<>();
+            tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+            sentencesLeft.add(shortSentence);
+            sentencesRight.add(longSentence);
+            sentencePairs.add(new Pair<>(shortSentence, longSentence));
+            labels.add("positive");
+            if (minibatchSize > 1) {
+                sentencesLeft.add(longSentence);
+                sentencesRight.add(shortSentence);
+                sentencePairs.add(new Pair<>(longSentence, shortSentence));
+                labels.add("negative");
+                if (minibatchSize > 2) {
+                    sentencesLeft.add(sentenceA);
+                    sentencesRight.add(sentenceB);
+                    sentencePairs.add(new Pair<>(sentenceA, sentenceB));
+                    labels.add("positive");
+                }
+            }
+            for (int i = 0; i < minibatchSize; i++) {
+                List<String> tokensL = tokenizer.create(sentencesLeft.get(i)).getTokens();
+                List<String> tokensR = tokenizer.create(sentencesRight.get(i)).getTokens();
+                if (i == 0) {
+                    shortL = tokensL.size();
+                    longL = tokensR.size();
+                }
+                if (i == 2) {
+                    sentenceALen = tokensL.size();
+                    sentenceBLen = tokensR.size();
+                }
+                tokenizedSentencesLeft.add(tokensL);
+                tokenizedSentencesRight.add(tokensR);
+            }
+            pairSentenceProvider = new CollectionLabeledPairSentenceProvider(sentencesLeft, sentencesRight, labels, null);
+        }
+    }
+
+    @Getter
+    private static class TestSentenceHelper {
+
+        private List<String> sentences;
+        private List<List<String>> tokenizedSentences;
+        private List<String> labels;
+        private int shortestL = 0;
+        private int longestL = 0;
+        private BertWordPieceTokenizerFactory tokenizer;
+        private CollectionLabeledSentenceProvider sentenceProvider;
+
+        private TestSentenceHelper() throws IOException {
+            this(false, 2);
+        }
+
+        private TestSentenceHelper(int minibatchSize) throws IOException {
+            this(false, minibatchSize);
+        }
+
+        private TestSentenceHelper(boolean alternateOrder) throws IOException {
+            this(false, 3);
+        }
+
+        private TestSentenceHelper(boolean alternateOrder, int minibatchSize) throws IOException {
+            sentences = new ArrayList<>();
+            labels = new ArrayList<>();
+            tokenizedSentences = new ArrayList<>();
+            tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+            if (!alternateOrder) {
+                sentences.add(shortSentence);
+                labels.add("positive");
+                if (minibatchSize > 1) {
+                    sentences.add(longSentence);
+                    labels.add("negative");
+                    if (minibatchSize > 2) {
+                        sentences.add(sentenceA);
+                        labels.add("positive");
+                    }
+                }
+            } else {
+                sentences.add(longSentence);
+                labels.add("negative");
+                if (minibatchSize > 1) {
+                    sentences.add(shortSentence);
+                    labels.add("positive");
+                    if (minibatchSize > 2) {
+                        sentences.add(sentenceB);
+                        labels.add("positive");
+                    }
+                }
+            }
+            for (int i = 0; i < sentences.size(); i++) {
+                List<String> tokenizedSentence = tokenizer.create(sentences.get(i)).getTokens();
+                if (i == 0)
+                    shortestL = tokenizedSentence.size();
+                if (tokenizedSentence.size() > longestL)
+                    longestL = tokenizedSentence.size();
+                if (tokenizedSentence.size() < shortestL)
+                    shortestL = tokenizedSentence.size();
+                tokenizedSentences.add(tokenizedSentence);
+            }
+            sentenceProvider = new CollectionLabeledSentenceProvider(sentences, labels, null);
+        }
+    }
 }
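
The two helpers replace the earlier hand-rolled LabeledSentenceProvider and LabeledPairSentenceProvider implementations with DL4J's collection-backed providers. A hedged sketch of that wiring outside the test class, using only calls that appear in the diff; the vocab path is a placeholder, and the third provider argument is null exactly as in the helpers:

import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;

import org.deeplearning4j.iterator.BertIterator;
import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider;
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;

public class HelperUsageSketch {
    public static void main(String[] args) throws Exception {
        // Placeholder vocab file; the tests resolve theirs from test resources.
        BertWordPieceTokenizerFactory t =
                new BertWordPieceTokenizerFactory(new File("vocab.txt"), false, false, StandardCharsets.UTF_8);

        List<String> sentences = Arrays.asList("I saw a girl with a telescope.", "Goodnight moon");
        List<String> labels = Arrays.asList("positive", "positive");

        // Collection-backed provider; the third constructor argument is null
        // here, mirroring the helpers above.
        CollectionLabeledSentenceProvider provider =
                new CollectionLabeledSentenceProvider(sentences, labels, null);

        BertIterator iter = BertIterator.builder()
                .tokenizer(t)
                .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
                .minibatchSize(2)
                .sentenceProvider(provider)
                .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
                .vocabMap(t.getVocab())
                .task(BertIterator.Task.SEQ_CLASSIFICATION)
                .build();

        System.out.println(iter.next().getFeatures(0));
    }
}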