diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java
index c6a88ffb3..b5ca6c91a 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java
@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -79,6 +80,17 @@ import java.util.Map;
* .build();
* }
*
+ *
+ * Example to use an instantiated iterator for inference:
+ *
+ * {@code
+ * BertIterator b;
+ * List forInference;
+ * Pair featuresAndMask = b.featurizeSentences(forInference);
+ * INDArray[] features = featuresAndMask.getFirst();
+ * INDArray[] featureMasks = featuresAndMask.getSecond();
+ * }
+ *
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.
*
* {@link LengthHandling} configuration:
@@ -89,7 +101,7 @@ import java.util.Map;
* CLIP_ONLY: For any sequences longer than the specified maximum, clip them. If the maximum sequence length in
* a minibatch is shorter than the specified maximum, no padding will occur. For sequences that are shorter than the
* maximum (within the current minibatch) they will be zero padded and masked.
- *
+ *
* {@link FeatureArrays} configuration:
* Determines what arrays should be included.
* INDICES_MASK: Indices array and mask array only, no segment ID array. Returns 1 feature array, 1 feature mask array (plus labels).
@@ -107,8 +119,11 @@ import java.util.Map;
public class BertIterator implements MultiDataSetIterator {
public enum Task {UNSUPERVISED, SEQ_CLASSIFICATION}
+
public enum LengthHandling {FIXED_LENGTH, ANY_LENGTH, CLIP_ONLY}
+
public enum FeatureArrays {INDICES_MASK, INDICES_MASK_SEGMENTID}
+
public enum UnsupervisedLabelFormat {RANK2_IDX, RANK3_NCL, RANK3_LNC}
protected Task task;
@@ -116,12 +131,13 @@ public class BertIterator implements MultiDataSetIterator {
protected int maxTokens = -1;
protected int minibatchSize = 32;
protected boolean padMinibatches = false;
- @Getter @Setter
+ @Getter
+ @Setter
protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null;
protected LengthHandling lengthHandling;
protected FeatureArrays featureArrays;
- protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
+ protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
protected BertSequenceMasker masker = null;
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
protected String maskToken;
@@ -130,7 +146,7 @@ public class BertIterator implements MultiDataSetIterator {
protected List vocabKeysAsList;
- protected BertIterator(Builder b){
+ protected BertIterator(Builder b) {
this.task = b.task;
this.tokenizerFactory = b.tokenizerFactory;
this.maxTokens = b.maxTokens;
@@ -166,10 +182,10 @@ public class BertIterator implements MultiDataSetIterator {
public MultiDataSet next(int num) {
Preconditions.checkState(hasNext(), "No next element available");
- List> list = new ArrayList<>(num);
- int count = 0;
- if(sentenceProvider != null){
- while(sentenceProvider.hasNext() && count++ < num) {
+ List> list = new ArrayList<>(num);
+ int mbSize = 0;
+ if (sentenceProvider != null) {
+ while (sentenceProvider.hasNext() && mbSize++ < num) {
list.add(sentenceProvider.nextSentence());
}
} else {
@@ -177,41 +193,60 @@ public class BertIterator implements MultiDataSetIterator {
throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
}
- //Get and tokenize the sentences for this minibatch
- List, String>> tokenizedSentences = new ArrayList<>(num);
- int longestSeq = -1;
- for(Pair p : list){
- List tokens = tokenizeSentence(p.getFirst());
- tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
- longestSeq = Math.max(longestSeq, tokens.size());
- }
- //Determine output array length...
- int outLength;
- switch (lengthHandling){
- case FIXED_LENGTH:
- outLength = maxTokens;
- break;
- case ANY_LENGTH:
- outLength = longestSeq;
- break;
- case CLIP_ONLY:
- outLength = Math.min(maxTokens, longestSeq);
- break;
- default:
- throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
- }
+ Pair, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list);
+ List, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
+ int outLength = outLTokenizedSentencesPair.getLeft();
- int mb = tokenizedSentences.size();
- int mbPadded = padMinibatches ? minibatchSize : mb;
+ Pair featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength);
+ INDArray[] featureArray = featuresAndMaskArraysPair.getFirst();
+ INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond();
+
+
+ Pair labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength);
+ INDArray[] labelArray = labelsAndMaskArraysPair.getFirst();
+ INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond();
+
+ org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(featureArray, labelArray, featureMaskArray, labelMaskArray);
+ if (preProcessor != null)
+ preProcessor.preProcess(mds);
+
+ return mds;
+ }
+
+
+ /**
+ * For use during inference. Will convert a given list of sentences to features and feature masks as appropriate.
+ *
+ * @param listOnlySentences
+ * @return Pair of INDArrays[], first element is feature arrays and the second is the masks array
+ */
+ public Pair featurizeSentences(List listOnlySentences) {
+
+ List> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
+
+ Pair, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel);
+ List, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
+ int outLength = outLTokenizedSentencesPair.getLeft();
+
+ Pair featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength);
+ if (preProcessor != null) {
+ MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null);
+ preProcessor.preProcess(dummyMDS);
+ return new Pair(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
+ }
+ return convertMiniBatchFeatures(tokenizedSentences, outLength);
+ }
+
+ private Pair convertMiniBatchFeatures(List, String>> tokenizedSentences, int outLength) {
+ int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
int[][] outIdxs = new int[mbPadded][outLength];
int[][] outMask = new int[mbPadded][outLength];
-
- for( int i=0; i,String> p = tokenizedSentences.get(i);
+ for (int i = 0; i < tokenizedSentences.size(); i++) {
+ Pair, String> p = tokenizedSentences.get(i);
List t = p.getFirst();
- for( int j=0; j(f, fm);
+ }
+ private Pair, String>>> tokenizeMiniBatch(List> list) {
+ //Get and tokenize the sentences for this minibatch
+ List, String>> tokenizedSentences = new ArrayList<>(list.size());
+ int longestSeq = -1;
+ for (Pair p : list) {
+ List tokens = tokenizeSentence(p.getFirst());
+ tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
+ longestSeq = Math.max(longestSeq, tokens.size());
+ }
+
+ //Determine output array length...
+ int outLength;
+ switch (lengthHandling) {
+ case FIXED_LENGTH:
+ outLength = maxTokens;
+ break;
+ case ANY_LENGTH:
+ outLength = longestSeq;
+ break;
+ case CLIP_ONLY:
+ outLength = Math.min(maxTokens, longestSeq);
+ break;
+ default:
+ throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
+ }
+ return new Pair<>(outLength, tokenizedSentences);
+ }
+
+ private Pair convertMiniBatchLabels(List, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
INDArray[] l = new INDArray[1];
INDArray[] lm;
- if(task == Task.SEQ_CLASSIFICATION){
+ int mbSize = tokenizedSentences.size();
+ int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
+ if (task == Task.SEQ_CLASSIFICATION) {
//Sequence classification task: output is 2d, one-hot, shape [minibatch, numClasses]
int numClasses;
int[] classLabels = new int[mbPadded];
- if(sentenceProvider != null){
+ if (sentenceProvider != null) {
numClasses = sentenceProvider.numLabelClasses();
List labels = sentenceProvider.allLabels();
- for(int i=0; i= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
@@ -252,21 +320,21 @@ public class BertIterator implements MultiDataSetIterator {
throw new RuntimeException();
}
l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses);
- for( int i=0; i e : vocabMap.entrySet()){
+ for (Map.Entry e : vocabMap.entrySet()) {
arr[e.getValue()] = e.getKey();
}
vocabKeysAsList = Arrays.asList(arr);
@@ -276,31 +344,31 @@ public class BertIterator implements MultiDataSetIterator {
int vocabSize = vocabMap.size();
INDArray labelArr;
INDArray lMask = Nd4j.zeros(DataType.INT, mbPadded, outLength);
- if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){
+ if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
labelArr = Nd4j.create(DataType.INT, mbPadded, outLength);
- } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){
+ } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength);
- } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){
+ } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize);
} else {
throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
}
- for( int i=0; i tokens = tokenizedSentences.get(i).getFirst();
- Pair,boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList);
+ Pair, boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList);
List maskedTokens = p.getFirst();
boolean[] predictionTarget = p.getSecond();
int seqLen = Math.min(predictionTarget.length, outLength);
- for(int j=0; j(l, lm);
}
private List tokenizeSentence(String sentence) {
Tokenizer t = tokenizerFactory.create(sentence);
List tokens = new ArrayList<>();
- if(prependToken != null)
+ if (prependToken != null)
tokens.add(prependToken);
while (t.hasMoreTokens()) {
@@ -341,6 +405,16 @@ public class BertIterator implements MultiDataSetIterator {
return tokens;
}
+
+ private List> addDummyLabel(List listOnlySentences) {
+ List> list = new ArrayList<>(listOnlySentences.size());
+ for (String s : listOnlySentences) {
+ list.add(new Pair(s, null));
+ }
+ return list;
+ }
+
+
@Override
public boolean resetSupported() {
return true;
@@ -353,12 +427,12 @@ public class BertIterator implements MultiDataSetIterator {
@Override
public void reset() {
- if(sentenceProvider != null){
+ if (sentenceProvider != null) {
sentenceProvider.reset();
}
}
- public static Builder builder(){
+ public static Builder builder() {
return new Builder();
}
@@ -373,7 +447,7 @@ public class BertIterator implements MultiDataSetIterator {
protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null;
protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
- protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
+ protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
protected BertSequenceMasker masker = new BertMaskedLMMasker();
protected UnsupervisedLabelFormat unsupervisedLabelFormat;
protected String maskToken;
@@ -382,7 +456,7 @@ public class BertIterator implements MultiDataSetIterator {
/**
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
*/
- public Builder task(Task task){
+ public Builder task(Task task) {
this.task = task;
return this;
}
@@ -392,18 +466,19 @@ public class BertIterator implements MultiDataSetIterator {
* For BERT, typically {@link org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory}
* is used
*/
- public Builder tokenizer(TokenizerFactory tokenizerFactory){
+ public Builder tokenizer(TokenizerFactory tokenizerFactory) {
this.tokenizerFactory = tokenizerFactory;
return this;
}
/**
* Specifies how the sequence length of the output data should be handled. See {@link BertIterator} for more details.
- * @param lengthHandling Length handling
- * @param maxLength Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH}
+ *
+ * @param lengthHandling Length handling
+ * @param maxLength Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH}
* @return
*/
- public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength){
+ public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength) {
this.lengthHandling = lengthHandling;
this.maxTokens = maxLength;
return this;
@@ -412,9 +487,10 @@ public class BertIterator implements MultiDataSetIterator {
/**
* Minibatch size to use (number of examples to train on for each iteration)
* See also: {@link #padMinibatches}
- * @param minibatchSize Minibatch size
+ *
+ * @param minibatchSize Minibatch size
*/
- public Builder minibatchSize(int minibatchSize){
+ public Builder minibatchSize(int minibatchSize) {
this.minibatchSize = minibatchSize;
return this;
}
@@ -429,7 +505,7 @@ public class BertIterator implements MultiDataSetIterator {
* Both options should result in exactly the same model. However, some BERT implementations may require exactly an
* exact number of examples in all minibatches to function.
*/
- public Builder padMinibatches(boolean padMinibatches){
+ public Builder padMinibatches(boolean padMinibatches) {
this.padMinibatches = padMinibatches;
return this;
}
@@ -437,7 +513,7 @@ public class BertIterator implements MultiDataSetIterator {
/**
* Set the preprocessor to be used on the MultiDataSets before returning them. Default: none (null)
*/
- public Builder preProcessor(MultiDataSetPreProcessor preProcessor){
+ public Builder preProcessor(MultiDataSetPreProcessor preProcessor) {
this.preProcessor = preProcessor;
return this;
}
@@ -446,7 +522,7 @@ public class BertIterator implements MultiDataSetIterator {
* Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised
* use case, the labels will be ignored.
*/
- public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider){
+ public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) {
this.sentenceProvider = sentenceProvider;
return this;
}
@@ -454,7 +530,7 @@ public class BertIterator implements MultiDataSetIterator {
/**
* Specify what arrays should be returned. See {@link BertIterator} for more details.
*/
- public Builder featureArrays(FeatureArrays featureArrays){
+ public Builder featureArrays(FeatureArrays featureArrays) {
this.featureArrays = featureArrays;
return this;
}
@@ -465,7 +541,7 @@ public class BertIterator implements MultiDataSetIterator {
* If using {@link BertWordPieceTokenizerFactory},
* this can be obtained using {@link BertWordPieceTokenizerFactory#getVocab()}
*/
- public Builder vocabMap(Map vocabMap){
+ public Builder vocabMap(Map vocabMap) {
this.vocabMap = vocabMap;
return this;
}
@@ -475,7 +551,7 @@ public class BertIterator implements MultiDataSetIterator {
* masked language model. This can be used to customize how the masking is performed.
* Default: {@link BertMaskedLMMasker}
*/
- public Builder masker(BertSequenceMasker masker){
+ public Builder masker(BertSequenceMasker masker) {
this.masker = masker;
return this;
}
@@ -485,7 +561,7 @@ public class BertIterator implements MultiDataSetIterator {
* masked language model. Used to specify the format that the labels should be returned in.
* See {@link BertIterator} for more details.
*/
- public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat){
+ public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat) {
this.unsupervisedLabelFormat = labelFormat;
return this;
}
@@ -497,7 +573,7 @@ public class BertIterator implements MultiDataSetIterator {
* the exact behaviour will depend on what masker is used.
* Note that this must be in the vocabulary map set in {@link #vocabMap}
*/
- public Builder maskToken(String maskToken){
+ public Builder maskToken(String maskToken) {
this.maskToken = maskToken;
return this;
}
@@ -510,12 +586,12 @@ public class BertIterator implements MultiDataSetIterator {
*
* @param prependToken The token to start each sequence with (null: no token will be prepended)
*/
- public Builder prependToken(String prependToken){
+ public Builder prependToken(String prependToken) {
this.prependToken = prependToken;
return this;
}
- public BertIterator build(){
+ public BertIterator build() {
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
Preconditions.checkState(vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map) to set");
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java
index 90879f858..d4be5e352 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java
@@ -26,7 +26,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
-import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.resources.Resources;
@@ -34,10 +33,7 @@ import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
+import java.util.*;
import static org.junit.Assert.*;
@@ -54,6 +50,9 @@ public class TestBertIterator extends BaseDL4JTest {
String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+ List forInference = new ArrayList<>();
+ forInference.add(toTokenize1);
+ forInference.add(toTokenize2);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
BertIterator b = BertIterator.builder()
@@ -100,12 +99,15 @@ public class TestBertIterator extends BaseDL4JTest {
assertEquals(expF, mds.getFeatures(0));
assertEquals(expM, mds.getFeaturesMaskArray(0));
+ assertEquals(expF,b.featurizeSentences(forInference).getFirst()[0]);
+ assertEquals(expM,b.featurizeSentences(forInference).getSecond()[0]);
assertFalse(b.hasNext());
b.reset();
assertTrue(b.hasNext());
MultiDataSet mds2 = b.next();
+ forInference.set(0,toTokenize2);
//Same thing, but with segment ID also
b = BertIterator.builder()
.tokenizer(t)
@@ -118,9 +120,11 @@ public class TestBertIterator extends BaseDL4JTest {
.build();
mds = b.next();
assertEquals(2, mds.getFeatures().length);
+ assertEquals(2,b.featurizeSentences(forInference).getFirst().length);
//Segment ID should be all 0s for single segment task
INDArray segmentId = expM.like();
assertEquals(segmentId, mds.getFeatures(1));
+ assertEquals(segmentId,b.featurizeSentences(forInference).getFirst()[1]);
}
@Test(timeout = 20000L)
@@ -157,6 +161,9 @@ public class TestBertIterator extends BaseDL4JTest {
public void testLengthHandling() throws Exception {
String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+ List forInference = new ArrayList<>();
+ forInference.add(toTokenize1);
+ forInference.add(toTokenize2);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
@@ -205,6 +212,8 @@ public class TestBertIterator extends BaseDL4JTest {
assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeatures(0));
assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeaturesMaskArray(0));
+ assertEquals(mds.getFeatures(0),b.featurizeSentences(forInference).getFirst()[0]);
+ assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]);
//Clip only: clip to maximum, but don't pad if less
b = BertIterator.builder()
@@ -227,6 +236,9 @@ public class TestBertIterator extends BaseDL4JTest {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
+ List forInference = new ArrayList<>();
+ forInference.add(toTokenize1);
+ forInference.add(toTokenize2);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
@@ -288,6 +300,9 @@ public class TestBertIterator extends BaseDL4JTest {
assertEquals(expM, mds.getFeaturesMaskArray(0));
assertEquals(expL, mds.getLabels(0));
assertEquals(expLM, mds.getLabelsMaskArray(0));
+
+ assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
+ assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
}
private static class TestSentenceProvider implements LabeledSentenceProvider {