parent 9592072cef
commit 63ed202057
@@ -17,11 +17,13 @@
 
 package org.deeplearning4j.iterator;
 
+import lombok.Getter;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
+import org.deeplearning4j.iterator.provider.CollectionLabeledPairSentenceProvider;
+import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider;
 import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
 import org.junit.Test;
-import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
@@ -42,8 +44,12 @@ import static org.junit.Assert.*;
 
 public class TestBertIterator extends BaseDL4JTest {
 
-    private File pathToVocab = Resources.asFile("other/vocab.txt");
+    private static File pathToVocab = Resources.asFile("other/vocab.txt");
     private static Charset c = StandardCharsets.UTF_8;
+    private static String shortSentence = "I saw a girl with a telescope.";
+    private static String longSentence = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+    private static String sentenceA = "Goodnight noises everywhere";
+    private static String sentenceB = "Goodnight moon";
 
    public TestBertIterator() throws IOException {
    }
@@ -51,20 +57,15 @@ public class TestBertIterator extends BaseDL4JTest {
     @Test(timeout = 20000L)
     public void testBertSequenceClassification() throws Exception {
 
-        String toTokenize1 = "I saw a girl with a telescope.";
+        int minibatchSize = 2;
-        String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+        TestSentenceHelper testHelper = new TestSentenceHelper();
-        List<String> forInference = new ArrayList<>();
-        forInference.add(toTokenize1);
-        forInference.add(toTokenize2);
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
-
         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
-                .minibatchSize(2)
+                .minibatchSize(minibatchSize)
-                .sentenceProvider(new TestSentenceProvider())
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
 
@@ -73,82 +74,77 @@ public class TestBertIterator extends BaseDL4JTest {
         System.out.println(mds.getFeatures(0));
         System.out.println(mds.getFeaturesMaskArray(0));
 
+        INDArray expF = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
+        INDArray expM = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
+        Map<String, Integer> m = testHelper.getTokenizer().getVocab();
-        List<String> tokens = t.create(toTokenize1).getTokens();
+        for (int i = 0; i < minibatchSize; i++) {
-        Map<String, Integer> m = t.getVocab();
+            INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
-        for (int i = 0; i < tokens.size(); i++) {
+            INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
-            int idx = m.get(tokens.get(i));
+            List<String> tokens = testHelper.getTokenizedSentences().get(i);
-            expEx0.putScalar(0, i, idx);
+            System.out.println(tokens);
-            expM0.putScalar(0, i, 1);
+            for (int j = 0; j < tokens.size(); j++) {
-        }
+                String token = tokens.get(j);
+                if (!m.containsKey(token)) {
-        INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
+                    throw new IllegalStateException("Unknown token: \"" + token + "\"");
-        INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
+                }
-        List<String> tokens2 = t.create(toTokenize2).getTokens();
+                int idx = m.get(token);
-        for (int i = 0; i < tokens2.size(); i++) {
+                expFTemp.putScalar(0, j, idx);
-            String token = tokens2.get(i);
+                expMTemp.putScalar(0, j, 1);
-            if (!m.containsKey(token)) {
+            }
-                throw new IllegalStateException("Unknown token: \"" + token + "\"");
+            if (i == 0) {
+                expF = expFTemp.dup();
+                expM = expMTemp.dup();
+            } else {
+                expF = Nd4j.vstack(expF, expFTemp);
+                expM = Nd4j.vstack(expM, expMTemp);
             }
-            int idx = m.get(token);
-            expEx1.putScalar(0, i, idx);
-            expM1.putScalar(0, i, 1);
         }
 
-        INDArray expF = Nd4j.vstack(expEx0, expEx1);
-        INDArray expM = Nd4j.vstack(expM0, expM1);
 
         assertEquals(expF, mds.getFeatures(0));
         assertEquals(expM, mds.getFeaturesMaskArray(0));
-        assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
+        assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
-        assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
+        assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);
 
-        b.next(); //pop the third element
         assertFalse(b.hasNext());
         b.reset();
         assertTrue(b.hasNext());
 
-        forInference.set(0, toTokenize2);
         //Same thing, but with segment ID also
         b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
-                .minibatchSize(2)
+                .minibatchSize(minibatchSize)
-                .sentenceProvider(new TestSentenceProvider())
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
         mds = b.next();
         assertEquals(2, mds.getFeatures().length);
-        //assertEquals(2, mds.getFeaturesMaskArrays().length); second element is null...
-        assertEquals(2, b.featurizeSentences(forInference).getFirst().length);
         //Segment ID should be all 0s for single segment task
         INDArray segmentId = expM.like();
         assertEquals(segmentId, mds.getFeatures(1));
-        assertEquals(segmentId, b.featurizeSentences(forInference).getFirst()[1]);
+        assertEquals(segmentId, b.featurizeSentences(testHelper.getSentences()).getFirst()[1]);
     }
 
     @Test(timeout = 20000L)
     public void testBertUnsupervised() throws Exception {
+        int minibatchSize = 2;
+        TestSentenceHelper testHelper = new TestSentenceHelper();
         //Task 1: Unsupervised
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
-                .minibatchSize(2)
+                .minibatchSize(minibatchSize)
-                .sentenceProvider(new TestSentenceProvider())
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.UNSUPERVISED)
                 .masker(new BertMaskedLMMasker(new Random(12345), 0.2, 0.5, 0.5))
                 .unsupervisedLabelFormat(BertIterator.UnsupervisedLabelFormat.RANK2_IDX)
                 .maskToken("[MASK]")
                 .build();
 
-        System.out.println("Mask token index: " + t.getVocab().get("[MASK]"));
+        System.out.println("Mask token index: " + testHelper.getTokenizer().getVocab().get("[MASK]"));
 
         MultiDataSet mds = b.next();
         System.out.println(mds.getFeatures(0));
@@ -156,7 +152,6 @@ public class TestBertIterator extends BaseDL4JTest {
         System.out.println(mds.getLabels(0));
         System.out.println(mds.getLabelsMaskArray(0));
 
-        b.next(); //pop the third element
         assertFalse(b.hasNext());
         b.reset();
         assertTrue(b.hasNext());
@@ -164,40 +159,34 @@ public class TestBertIterator extends BaseDL4JTest {
 
     @Test(timeout = 20000L)
     public void testLengthHandling() throws Exception {
-        String toTokenize1 = "I saw a girl with a telescope.";
+        int minibatchSize = 2;
-        String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+        TestSentenceHelper testHelper = new TestSentenceHelper();
-        List<String> forInference = new ArrayList<>();
+        INDArray expF = Nd4j.create(DataType.INT, 1, 16);
-        forInference.add(toTokenize1);
+        INDArray expM = Nd4j.create(DataType.INT, 1, 16);
-        forInference.add(toTokenize2);
+        Map<String, Integer> m = testHelper.getTokenizer().getVocab();
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+        for (int i = 0; i < minibatchSize; i++) {
-        INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
+            List<String> tokens = testHelper.getTokenizedSentences().get(i);
-        INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
+            INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens = t.create(toTokenize1).getTokens();
+            INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
             System.out.println(tokens);
-        Map<String, Integer> m = t.getVocab();
+            for (int j = 0; j < tokens.size(); j++) {
-        for (int i = 0; i < tokens.size(); i++) {
+                String token = tokens.get(j);
-            int idx = m.get(tokens.get(i));
+                if (!m.containsKey(token)) {
-            expEx0.putScalar(0, i, idx);
+                    throw new IllegalStateException("Unknown token: \"" + token + "\"");
-            expM0.putScalar(0, i, 1);
+                }
-        }
+                int idx = m.get(token);
+                expFTemp.putScalar(0, j, idx);
-        INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
+                expMTemp.putScalar(0, j, 1);
-        INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
+            }
-        List<String> tokens2 = t.create(toTokenize2).getTokens();
+            if (i == 0) {
-        System.out.println(tokens2);
+                expF = expFTemp.dup();
-        for (int i = 0; i < tokens2.size(); i++) {
+                expM = expMTemp.dup();
-            String token = tokens2.get(i);
+            } else {
-            if (!m.containsKey(token)) {
+                expF = Nd4j.vstack(expF, expFTemp);
-                throw new IllegalStateException("Unknown token: \"" + token + "\"");
+                expM = Nd4j.vstack(expM, expMTemp);
             }
-            int idx = m.get(token);
-            expEx1.putScalar(0, i, idx);
-            expM1.putScalar(0, i, 1);
         }
 
-        INDArray expF = Nd4j.vstack(expEx0, expEx1);
-        INDArray expM = Nd4j.vstack(expM0, expM1);
 
         //--------------------------------------------------------------
 
         //Fixed length: clip or pad - already tested in other tests
@@ -205,12 +194,12 @@ public class TestBertIterator extends BaseDL4JTest {
         //Any length: as long as we need to fit longest sequence
 
         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.ANY_LENGTH, -1)
-                .minibatchSize(2)
+                .minibatchSize(minibatchSize)
-                .sentenceProvider(new TestSentenceProvider())
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
         MultiDataSet mds = b.next();
@@ -219,20 +208,19 @@ public class TestBertIterator extends BaseDL4JTest {
         assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
         assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeatures(0));
         assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeaturesMaskArray(0));
-        assertEquals(mds.getFeatures(0), b.featurizeSentences(forInference).getFirst()[0]);
+        assertEquals(mds.getFeatures(0), b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
-        assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]);
+        assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);
 
         //Clip only: clip to maximum, but don't pad if less
         b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.CLIP_ONLY, 20)
-                .minibatchSize(2)
+                .minibatchSize(minibatchSize)
-                .sentenceProvider(new TestSentenceProvider())
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
-        mds = b.next();
         expShape = new long[]{2, 14};
         assertArrayEquals(expShape, mds.getFeatures(0).shape());
         assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
@@ -241,54 +229,38 @@ public class TestBertIterator extends BaseDL4JTest {
     @Test(timeout = 20000L)
     public void testMinibatchPadding() throws Exception {
         Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
-        String toTokenize1 = "I saw a girl with a telescope.";
+        int minibatchSize = 3;
-        String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+        TestSentenceHelper testHelper = new TestSentenceHelper(minibatchSize);
-        String toTokenize3 = "Goodnight noises everywhere";
-        List<String> forInference = new ArrayList<>();
-        forInference.add(toTokenize1);
-        forInference.add(toTokenize2);
-        forInference.add(toTokenize3);
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
-        INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens = t.create(toTokenize1).getTokens();
-        Map<String, Integer> m = t.getVocab();
-        for (int i = 0; i < tokens.size(); i++) {
-            int idx = m.get(tokens.get(i));
-            expEx0.putScalar(0, i, idx);
-            expM0.putScalar(0, i, 1);
-        }
-
-        INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens2 = t.create(toTokenize2).getTokens();
-        for (int i = 0; i < tokens2.size(); i++) {
-            String token = tokens2.get(i);
-            if (!m.containsKey(token)) {
-                throw new IllegalStateException("Unknown token: \"" + token + "\"");
-            }
-            int idx = m.get(token);
-            expEx1.putScalar(0, i, idx);
-            expM1.putScalar(0, i, 1);
-        }
-
-        INDArray expEx3 = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM3 = Nd4j.create(DataType.INT, 1, 16);
-        List<String> tokens3 = t.create(toTokenize3).getTokens();
-        for (int i = 0; i < tokens3.size(); i++) {
-            String token = tokens3.get(i);
-            if (!m.containsKey(token)) {
-                throw new IllegalStateException("Unknown token: \"" + token + "\"");
-            }
-            int idx = m.get(token);
-            expEx3.putScalar(0, i, idx);
-            expM3.putScalar(0, i, 1);
-        }
-
         INDArray zeros = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expF = Nd4j.vstack(expEx0, expEx1, expEx3, zeros);
+        INDArray expF = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expM = Nd4j.vstack(expM0, expM1, expM3, zeros);
+        INDArray expM = Nd4j.create(DataType.INT, 1, 16);
-        INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {1, 0}, {0, 0}});
+        Map<String, Integer> m = testHelper.getTokenizer().getVocab();
+        for (int i = 0; i < minibatchSize; i++) {
+            List<String> tokens = testHelper.getTokenizedSentences().get(i);
+            INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
+            INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
+            System.out.println(tokens);
+            for (int j = 0; j < tokens.size(); j++) {
+                String token = tokens.get(j);
+                if (!m.containsKey(token)) {
+                    throw new IllegalStateException("Unknown token: \"" + token + "\"");
+                }
+                int idx = m.get(token);
+                expFTemp.putScalar(0, j, idx);
+                expMTemp.putScalar(0, j, 1);
+            }
+            if (i == 0) {
+                expF = expFTemp.dup();
+                expM = expMTemp.dup();
+            } else {
+                expF = Nd4j.vstack(expF.dup(), expFTemp);
+                expM = Nd4j.vstack(expM.dup(), expMTemp);
+            }
+        }
+
+        expF = Nd4j.vstack(expF, zeros);
+        expM = Nd4j.vstack(expM, zeros);
+        INDArray expL = Nd4j.createFromArray(new float[][]{{0, 1}, {1, 0}, {0, 1}, {0, 0}});
         INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1);
         expLM.putScalar(0, 0, 1);
         expLM.putScalar(1, 0, 1);
@@ -297,13 +269,13 @@ public class TestBertIterator extends BaseDL4JTest {
         //--------------------------------------------------------------
 
         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testHelper.getTokenizer())
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
-                .minibatchSize(4)
+                .minibatchSize(minibatchSize + 1)
                 .padMinibatches(true)
-                .sentenceProvider(new TestSentenceProvider())
+                .sentenceProvider(testHelper.getSentenceProvider())
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                .vocabMap(t.getVocab())
+                .vocabMap(testHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .build();
 
@@ -323,170 +295,175 @@ public class TestBertIterator extends BaseDL4JTest {
         assertEquals(expL, mds.getLabels(0));
         assertEquals(expLM, mds.getLabelsMaskArray(0));
 
-        assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
+        assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
-        assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
+        assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);
     }
 
+    /*
+    Checks that a mds from a pair sentence is equal to hstack'd mds from the left side and right side of the pair
+    Checks different lengths for max length to check popping and padding
+     */
     @Test
     public void testSentencePairsSingle() throws IOException {
-        String shortSent = "I saw a girl with a telescope.";
-        String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
         boolean prependAppend;
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+        int numOfSentences;
-        int shortL = t.create(shortSent).countTokens();
-        int longL = t.create(longSent).countTokens();
+        TestSentenceHelper testHelper = new TestSentenceHelper();
+        int shortL = testHelper.getShortestL();
+        int longL = testHelper.getLongestL();
 
         Triple<MultiDataSet, MultiDataSet, MultiDataSet> multiDataSetTriple;
-        MultiDataSet shortLongPair, shortSentence, longSentence;
+        MultiDataSet fromPair, leftSide, rightSide;
 
         // check for pair max length exactly equal to sum of lengths - pop neither no padding
         // should be the same as hstack with segment ids 1 for second sentence
         prependAppend = true;
-        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend);
+        numOfSentences = 1;
-        shortLongPair = multiDataSetTriple.getFirst();
+        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend, numOfSentences);
-        shortSentence = multiDataSetTriple.getSecond();
+        fromPair = multiDataSetTriple.getFirst();
-        longSentence = multiDataSetTriple.getThird();
+        leftSide = multiDataSetTriple.getSecond();
-        assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
+        rightSide = multiDataSetTriple.getThird();
-        longSentence.getFeatures(1).addi(1);
+        assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
-        assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
+        rightSide.getFeatures(1).addi(1); //add 1 for right side segment ids
-        assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
+        assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
+        assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));
 
         //check for pair max length greater than sum of lengths - pop neither with padding
         // features should be the same as hstack of shorter and longer padded with prepend/append
         // segment id should 1 only in the longer for part of the length of the sentence
         prependAppend = true;
-        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend);
+        numOfSentences = 1;
-        shortLongPair = multiDataSetTriple.getFirst();
+        multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend, numOfSentences);
-        shortSentence = multiDataSetTriple.getSecond();
+        fromPair = multiDataSetTriple.getFirst();
-        longSentence = multiDataSetTriple.getThird();
+        leftSide = multiDataSetTriple.getSecond();
-        assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
+        rightSide = multiDataSetTriple.getThird();
-        longSentence.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part
+        assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
-        assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
+        rightSide.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part
-        assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
+        assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
+        assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));
 
         //check for pair max length less than shorter sentence - pop both
         //should be the same as hstack with segment ids 1 for second sentence if no prepend/append
-        int maxL = shortL - 2;
+        int maxL = 5;//checking odd
+        numOfSentences = 3;
         prependAppend = false;
-        multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend);
+        multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend, numOfSentences);
-        shortLongPair = multiDataSetTriple.getFirst();
+        fromPair = multiDataSetTriple.getFirst();
-        shortSentence = multiDataSetTriple.getSecond();
+        leftSide = multiDataSetTriple.getSecond();
-        longSentence = multiDataSetTriple.getThird();
+        rightSide = multiDataSetTriple.getThird();
-        assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
+        assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
-        longSentence.getFeatures(1).addi(1);
+        rightSide.getFeatures(1).addi(1);
-        assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
+        assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
-        assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
+        assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));
     }
 
+    /*
+    Same idea as previous test - construct mds from bert iterator with sep sentences and check against one with pairs
+    Checks various max lengths
+    Has sentences of varying lengths
+     */
     @Test
     public void testSentencePairsUnequalLengths() throws IOException {
-        //check for pop only longer (i.e between longer and longer + shorter), first row pop from second sentence, next row pop from first sentence, nothing to pop in the third row
-        //should be identical to hstack if there is no append, prepend
+        int minibatchSize = 4;
-        //batch size is 2
+        int numOfSentencesinIter = 3;
-        int mbS = 4;
-        String shortSent = "I saw a girl with a telescope.";
+        TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(numOfSentencesinIter);
-        String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+        int shortL = testPairHelper.getShortL();
-        String sent1 = "Goodnight noises everywhere"; //shorter than shortSent - no popping
+        int longL = testPairHelper.getLongL();
-        String sent2 = "Goodnight moon"; //shorter than shortSent - no popping
+        int sent1L = testPairHelper.getSentenceALen();
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+        int sent2L = testPairHelper.getSentenceBLen();
-        int shortL = t.create(shortSent).countTokens();
-        int longL = t.create(longSent).countTokens();
+        System.out.println("Sentence Pairs, Left");
-        int sent1L = t.create(sent1).countTokens();
+        System.out.println(testPairHelper.getSentencesLeft());
-        int sent2L = t.create(sent2).countTokens();
+        System.out.println("Sentence Pairs, Right");
-        //won't check 2*shortL + 1 because this will always pop on the left
+        System.out.println(testPairHelper.getSentencesRight());
-        for (int maxL = longL + shortL - 1; maxL > 2 * shortL; maxL--) {
+        //anything outside this range more will need to check padding,truncation
+        for (int maxL = longL + shortL; maxL > 2 * shortL + 1; maxL--) {
 
+            System.out.println("Running for max length = " + maxL);
 
             MultiDataSet leftMDS = BertIterator.builder()
-                    .tokenizer(t)
+                    .tokenizer(testPairHelper.getTokenizer())
-                    .minibatchSize(mbS)
+                    .minibatchSize(minibatchSize)
                     .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                    .vocabMap(t.getVocab())
+                    .vocabMap(testPairHelper.getTokenizer().getVocab())
                     .task(BertIterator.Task.SEQ_CLASSIFICATION)
-                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
+                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either
-                    .sentenceProvider(new TestSentenceProvider())
+                    .sentenceProvider(new TestSentenceHelper(numOfSentencesinIter).getSentenceProvider())
                     .padMinibatches(true)
                     .build().next();
 
             MultiDataSet rightMDS = BertIterator.builder()
-                    .tokenizer(t)
+                    .tokenizer(testPairHelper.getTokenizer())
-                    .minibatchSize(mbS)
+                    .minibatchSize(minibatchSize)
                     .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                    .vocabMap(t.getVocab())
+                    .vocabMap(testPairHelper.getTokenizer().getVocab())
                     .task(BertIterator.Task.SEQ_CLASSIFICATION)
-                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
+                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either
-                    .sentenceProvider(new TestSentenceProvider(true))
+                    .sentenceProvider(new TestSentenceHelper(true, numOfSentencesinIter).getSentenceProvider())
                    .padMinibatches(true)
                    .build().next();
 
             MultiDataSet pairMDS = BertIterator.builder()
-                    .tokenizer(t)
+                    .tokenizer(testPairHelper.getTokenizer())
-                    .minibatchSize(mbS)
+                    .minibatchSize(minibatchSize)
                     .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                    .vocabMap(t.getVocab())
+                    .vocabMap(testPairHelper.getTokenizer().getVocab())
                     .task(BertIterator.Task.SEQ_CLASSIFICATION)
-                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) //random big num guaranteed to be longer than either
+                    .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL)
-                    .sentencePairProvider(new TestSentencePairProvider())
+                    .sentencePairProvider(testPairHelper.getPairSentenceProvider())
                     .padMinibatches(true)
                     .build().next();
 
-            //Left sentences here are {{shortSent},
-            // {longSent},
-            // {Sent1}}
-            //Right sentences here are {{longSent},
-            // {shortSent},
-            // {Sent2}}
-            //The sentence pairs here are {{shortSent,longSent},
-            // {longSent,shortSent}
-            // {Sent1, Sent2}}
 
             //CHECK FEATURES
-            INDArray combinedFeat = Nd4j.create(DataType.INT,mbS,maxL);
+            INDArray combinedFeat = Nd4j.create(DataType.INT, minibatchSize, maxL);
             //left side
             INDArray leftFeatures = leftMDS.getFeatures(0);
             INDArray topLSentFeat = leftFeatures.getRow(0).get(NDArrayIndex.interval(0, shortL));
             INDArray midLSentFeat = leftFeatures.getRow(1).get(NDArrayIndex.interval(0, maxL - shortL));
-            INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0,sent1L));
+            INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0, sent1L));
             //right side
             INDArray rightFeatures = rightMDS.getFeatures(0);
             INDArray topRSentFeat = rightFeatures.getRow(0).get(NDArrayIndex.interval(0, maxL - shortL));
             INDArray midRSentFeat = rightFeatures.getRow(1).get(NDArrayIndex.interval(0, shortL));
-            INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0,sent2L));
+            INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0, sent2L));
             //expected pair
-            combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat,topRSentFeat));
+            combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat, topRSentFeat));
-            combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat,midRSentFeat));
+            combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat, midRSentFeat));
-            combinedFeat.getRow(2).get(NDArrayIndex.interval(0,sent1L+sent2L)).addi(Nd4j.hstack(bottomLSentFeat,bottomRSentFeat));
+            combinedFeat.getRow(2).get(NDArrayIndex.interval(0, sent1L + sent2L)).addi(Nd4j.hstack(bottomLSentFeat, bottomRSentFeat));
 
             assertEquals(maxL, pairMDS.getFeatures(0).shape()[1]);
             assertArrayEquals(combinedFeat.shape(), pairMDS.getFeatures(0).shape());
             assertEquals(combinedFeat, pairMDS.getFeatures(0));
 
             //CHECK SEGMENT ID
-            INDArray combinedFetSeg = Nd4j.create(DataType.INT, mbS, maxL);
+            INDArray combinedFetSeg = Nd4j.create(DataType.INT, minibatchSize, maxL);
             combinedFetSeg.get(NDArrayIndex.point(0), NDArrayIndex.interval(shortL, maxL)).addi(1);
             combinedFetSeg.get(NDArrayIndex.point(1), NDArrayIndex.interval(maxL - shortL, maxL)).addi(1);
-            combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L+sent2L)).addi(1);
+            combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L + sent2L)).addi(1);
             assertArrayEquals(combinedFetSeg.shape(), pairMDS.getFeatures(1).shape());
             assertEquals(maxL, combinedFetSeg.shape()[1]);
             assertEquals(combinedFetSeg, pairMDS.getFeatures(1));
 
+            testPairHelper.getPairSentenceProvider().reset();
         }
     }
 
     @Test
     public void testSentencePairFeaturizer() throws IOException {
-        String shortSent = "I saw a girl with a telescope.";
+        int minibatchSize = 2;
-        String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+        TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(minibatchSize);
-        List<Pair<String, String>> listSentencePair = new ArrayList<>();
-        listSentencePair.add(new Pair<>(shortSent, longSent));
-        listSentencePair.add(new Pair<>(longSent, shortSent));
-        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
         BertIterator b = BertIterator.builder()
-                .tokenizer(t)
+                .tokenizer(testPairHelper.getTokenizer())
-                .minibatchSize(2)
+                .minibatchSize(minibatchSize)
                 .padMinibatches(true)
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
-                .vocabMap(t.getVocab())
+                .vocabMap(testPairHelper.getTokenizer().getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION)
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)
-                .sentencePairProvider(new TestSentencePairProvider())
+                .sentencePairProvider(testPairHelper.getPairSentenceProvider())
                 .prependToken("[CLS]")
                 .appendToken("[SEP]")
                 .build();
@@ -494,23 +471,19 @@ public class TestBertIterator extends BaseDL4JTest {
         INDArray[] featuresArr = mds.getFeatures();
         INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays();
 
-        Pair<INDArray[], INDArray[]> p = b.featurizeSentencePairs(listSentencePair);
+        Pair<INDArray[], INDArray[]> p = b.featurizeSentencePairs(testPairHelper.getSentencePairs());
         assertEquals(p.getFirst().length, 2);
         assertEquals(featuresArr[0], p.getFirst()[0]);
         assertEquals(featuresArr[1], p.getFirst()[1]);
-        //assertEquals(p.getSecond().length, 2);
         assertEquals(featuresMaskArr[0], p.getSecond()[0]);
-        //assertEquals(featuresMaskArr[1], p.getSecond()[1]);
     }
 
     /**
-     * Returns three multidatasets from bert iterator based on given max lengths and whether to prepend/append
+     * Returns three multidatasets (one from pair of sentences and the other two from single sentence lists) from bert iterator
+     * with given max lengths and whether to prepend/append
     * Idea is the sentence pair dataset can be constructed from the single sentence datasets
-     * First one is constructed from a sentence pair "I saw a girl with a telescope." & "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"
-     * Second one is constructed from the left of the sentence pair i.e "I saw a girl with a telescope."
-     * Third one is constructed from the right of the sentence pair i.e "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"
     */
-    private Triple<MultiDataSet, MultiDataSet, MultiDataSet> generateMultiDataSets(Triple<Integer, Integer, Integer> maxLengths, boolean prependAppend) throws IOException {
+    private Triple<MultiDataSet, MultiDataSet, MultiDataSet> generateMultiDataSets(Triple<Integer, Integer, Integer> maxLengths, boolean prependAppend, int numSentences) throws IOException {
         BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
         int maxforPair = maxLengths.getFirst();
         int maxPartOne = maxLengths.getSecond();
@@ -518,133 +491,155 @@ public class TestBertIterator extends BaseDL4JTest {
         BertIterator.Builder commonBuilder;
         commonBuilder = BertIterator.builder()
                 .tokenizer(t)
-                .minibatchSize(1)
+                .minibatchSize(4)
                 .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
                 .vocabMap(t.getVocab())
                 .task(BertIterator.Task.SEQ_CLASSIFICATION);
-        BertIterator shortLongPairFirstIter = commonBuilder
+        BertIterator pairIter = commonBuilder
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxforPair + 3 : maxforPair)
-                .sentencePairProvider(new TestSentencePairProvider())
+                .sentencePairProvider(new TestSentencePairsHelper(numSentences).getPairSentenceProvider())
                 .prependToken(prependAppend ? "[CLS]" : null)
                 .appendToken(prependAppend ? "[SEP]" : null)
                 .build();
-        BertIterator shortFirstIter = commonBuilder
+        BertIterator leftIter = commonBuilder
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartOne + 2 : maxPartOne)
-                .sentenceProvider(new TestSentenceProvider())
+                .sentenceProvider(new TestSentenceHelper(numSentences).getSentenceProvider())
                 .prependToken(prependAppend ? "[CLS]" : null)
                 .appendToken(prependAppend ? "[SEP]" : null)
                 .build();
-        BertIterator longFirstIter = commonBuilder
+        BertIterator rightIter = commonBuilder
                 .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartTwo + 1 : maxPartTwo)
-                .sentenceProvider(new TestSentenceProvider(true))
+                .sentenceProvider(new TestSentenceHelper(true, numSentences).getSentenceProvider())
                 .prependToken(null)
                 .appendToken(prependAppend ? "[SEP]" : null)
                 .build();
-        return new Triple<>(shortLongPairFirstIter.next(), shortFirstIter.next(), longFirstIter.next());
+        return new Triple<>(pairIter.next(), leftIter.next(), rightIter.next());
     }
 
-    private static class TestSentenceProvider implements LabeledSentenceProvider {
+    @Getter
+    private static class TestSentencePairsHelper {
 
-        private int pos = 0;
+        private List<String> sentencesLeft;
-        private boolean invert;
+        private List<String> sentencesRight;
+        private List<Pair<String, String>> sentencePairs;
+        private List<List<String>> tokenizedSentencesLeft;
+        private List<List<String>> tokenizedSentencesRight;
+        private List<String> labels;
+        private int shortL;
+        private int longL;
+        private int sentenceALen;
+        private int sentenceBLen;
+        private BertWordPieceTokenizerFactory tokenizer;
+        private CollectionLabeledPairSentenceProvider pairSentenceProvider;
 
-        private TestSentenceProvider() {
+        private TestSentencePairsHelper() throws IOException {
-            this.invert = false;
+            this(3);
         }
 
-        private TestSentenceProvider(boolean invert) {
+        private TestSentencePairsHelper(int minibatchSize) throws IOException {
-            this.invert = invert;
+            sentencesLeft = new ArrayList<>();
-        }
+            sentencesRight = new ArrayList<>();
+            sentencePairs = new ArrayList<>();
-        @Override
+            labels = new ArrayList<>();
-        public boolean hasNext() {
+            tokenizedSentencesLeft = new ArrayList<>();
-            return pos < totalNumSentences();
+            tokenizedSentencesRight = new ArrayList<>();
-        }
+            tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+            sentencesLeft.add(shortSentence);
-        @Override
+            sentencesRight.add(longSentence);
-        public Pair<String, String> nextSentence() {
+            sentencePairs.add(new Pair<>(shortSentence, longSentence));
-            Preconditions.checkState(hasNext());
+            labels.add("positive");
-            if (pos == 0) {
+            if (minibatchSize > 1) {
-                pos++;
+                sentencesLeft.add(longSentence);
-                if (!invert) return new Pair<>("I saw a girl with a telescope.", "positive");
+                sentencesRight.add(shortSentence);
-                return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
+                sentencePairs.add(new Pair<>(longSentence, shortSentence));
-            } else {
+                labels.add("negative");
-                if (pos == 1) {
+                if (minibatchSize > 2) {
-                    pos++;
+                    sentencesLeft.add(sentenceA);
-                    if (!invert) return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
+                    sentencesRight.add(sentenceB);
-                    return new Pair<>("I saw a girl with a telescope.", "positive");
+                    sentencePairs.add(new Pair<>(sentenceA, sentenceB));
+                    labels.add("positive");
                 }
-                pos++;
-                if (!invert)
-                    return new Pair<>("Goodnight noises everywhere", "positive");
-                return new Pair<>("Goodnight moon", "positive");
             }
-        }
+            for (int i = 0; i < minibatchSize; i++) {
+                List<String> tokensL = tokenizer.create(sentencesLeft.get(i)).getTokens();
-        @Override
+                List<String> tokensR = tokenizer.create(sentencesRight.get(i)).getTokens();
-        public void reset() {
+                if (i == 0) {
-            pos = 0;
+                    shortL = tokensL.size();
-        }
+                    longL = tokensR.size();
+                }
-        @Override
+                if (i == 2) {
-        public int totalNumSentences() {
+                    sentenceALen = tokensL.size();
-            return 3;
+                    sentenceBLen = tokensR.size();
-        }
+                }
+                tokenizedSentencesLeft.add(tokensL);
-        @Override
+                tokenizedSentencesRight.add(tokensR);
-        public List<String> allLabels() {
+            }
-            return Arrays.asList("positive", "negative");
+            pairSentenceProvider = new CollectionLabeledPairSentenceProvider(sentencesLeft, sentencesRight, labels, null);
-        }
 
-        @Override
-        public int numLabelClasses() {
-            return 2;
         }
     }
 
-    private static class TestSentencePairProvider implements LabeledPairSentenceProvider {
+    @Getter
+    private static class TestSentenceHelper {
 
-        private int pos = 0;
+        private List<String> sentences;
+        private List<List<String>> tokenizedSentences;
+        private List<String> labels;
+        private int shortestL = 0;
+        private int longestL = 0;
+        private BertWordPieceTokenizerFactory tokenizer;
+        private CollectionLabeledSentenceProvider sentenceProvider;
 
-        @Override
+        private TestSentenceHelper() throws IOException {
-        public boolean hasNext() {
+            this(false, 2);
-            return pos < totalNumSentences();
         }
 
-        @Override
+        private TestSentenceHelper(int minibatchSize) throws IOException {
-        public Triple<String, String, String> nextSentencePair() {
+            this(false, minibatchSize);
-            Preconditions.checkState(hasNext());
+        }
-            if (pos == 0) {
-                pos++;
+        private TestSentenceHelper(boolean alternateOrder) throws IOException {
-                return new Triple<>("I saw a girl with a telescope.", "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "positive");
+            this(false, 3);
-            } else {
+        }
-                if (pos == 1) {
-                    pos++;
+        private TestSentenceHelper(boolean alternateOrder, int minibatchSize) throws IOException {
-                    return new Triple<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "I saw a girl with a telescope.", "negative");
+            sentences = new ArrayList<>();
+            labels = new ArrayList<>();
+            tokenizedSentences = new ArrayList<>();
+            tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
+            if (!alternateOrder) {
+                sentences.add(shortSentence);
+                labels.add("positive");
+                if (minibatchSize > 1) {
+                    sentences.add(longSentence);
+                    labels.add("negative");
+                    if (minibatchSize > 2) {
+                        sentences.add(sentenceA);
+                        labels.add("positive");
+                    }
+                }
+            } else {
+                sentences.add(longSentence);
+                labels.add("negative");
+                if (minibatchSize > 1) {
+                    sentences.add(shortSentence);
+                    labels.add("positive");
+                    if (minibatchSize > 2) {
+                        sentences.add(sentenceB);
+                        labels.add("positive");
+                    }
+                }
             }
-            pos++;
-            return new Triple<>("Goodnight noises everywhere", "Goodnight moon", "positive");
-        }
+            for (int i = 0; i < sentences.size(); i++) {
+                List<String> tokenizedSentence = tokenizer.create(sentences.get(i)).getTokens();
-        @Override
+                if (i == 0)
-        public void reset() {
+                    shortestL = tokenizedSentence.size();
-            pos = 0;
+                if (tokenizedSentence.size() > longestL)
-        }
+                    longestL = tokenizedSentence.size();
+                if (tokenizedSentence.size() < shortestL)
-        @Override
+                    shortestL = tokenizedSentence.size();
-        public int totalNumSentences() {
+                tokenizedSentences.add(tokenizedSentence);
-            return 3;
+            }
-        }
+            sentenceProvider = new CollectionLabeledSentenceProvider(sentences, labels, null);
 
-        @Override
-        public List<String> allLabels() {
-            return Arrays.asList("positive", "negative");
-        }
-
-        @Override
-        public int numLabelClasses() {
-            return 2;
        }
    }
 
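Note: as a reading aid, below is a minimal, self-contained sketch of the BertIterator usage pattern that the refactored tests above exercise. It is assembled only from calls that appear in this diff; the vocab file path and the sentence/label lists are placeholder values and are not part of the commit.

import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.iterator.BertIterator;
import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider;
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
import org.nd4j.linalg.dataset.api.MultiDataSet;

public class BertIteratorSketch {
    public static void main(String[] args) throws Exception {
        File vocab = new File("vocab.txt"); // placeholder; the tests load other/vocab.txt from test resources
        BertWordPieceTokenizerFactory t =
                new BertWordPieceTokenizerFactory(vocab, false, false, StandardCharsets.UTF_8);

        // Labeled sentences backing the iterator, analogous to what TestSentenceHelper builds
        List<String> sentences = Arrays.asList("I saw a girl with a telescope.", "Goodnight moon");
        List<String> labels = Arrays.asList("positive", "positive");

        BertIterator b = BertIterator.builder()
                .tokenizer(t)
                .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
                .minibatchSize(2)
                .sentenceProvider(new CollectionLabeledSentenceProvider(sentences, labels, null))
                .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
                .vocabMap(t.getVocab())
                .task(BertIterator.Task.SEQ_CLASSIFICATION)
                .build();

        MultiDataSet mds = b.next();                     // one minibatch of featurized sentences
        System.out.println(mds.getFeatures(0));          // token indices, shape [2, 16]
        System.out.println(mds.getFeaturesMaskArray(0)); // 1 where a real token is present, 0 for padding
    }
}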