* Convenience method for inference with BERT iterator Signed-off-by: eraly <susan.eraly@gmail.com> * Included preprocessing Signed-off-by: eraly <susan.eraly@gmail.com> * Copyright + example Signed-off-by: eraly <susan.eraly@gmail.com>master
parent
7a90a31cfb
commit
823bd0ff88
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -79,6 +80,17 @@ import java.util.Map;
|
||||||
* .build();
|
* .build();
|
||||||
* }
|
* }
|
||||||
* </pre>
|
* </pre>
|
||||||
|
* <br>
|
||||||
|
* <b>Example to use an instantiated iterator for inference:</b><br>
|
||||||
|
* <pre>
|
||||||
|
* {@code
|
||||||
|
* BertIterator b;
|
||||||
|
* List<String> forInference;
|
||||||
|
* Pair<INDArray[],INDArray[]> featuresAndMask = b.featurizeSentences(forInference);
|
||||||
|
* INDArray[] features = featuresAndMask.getFirst();
|
||||||
|
* INDArray[] featureMasks = featuresAndMask.getSecond();
|
||||||
|
* }
|
||||||
|
* </pre>
|
||||||
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br>
|
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br>
|
||||||
* <br>
|
* <br>
|
||||||
* <u><b>{@link LengthHandling} configuration:</b></u><br>
|
* <u><b>{@link LengthHandling} configuration:</b></u><br>
|
||||||
|
@ -89,7 +101,7 @@ import java.util.Map;
|
||||||
* <b>CLIP_ONLY</b>: For any sequences longer than the specified maximum, clip them. If the maximum sequence length in
|
* <b>CLIP_ONLY</b>: For any sequences longer than the specified maximum, clip them. If the maximum sequence length in
|
||||||
* a minibatch is shorter than the specified maximum, no padding will occur. For sequences that are shorter than the
|
* a minibatch is shorter than the specified maximum, no padding will occur. For sequences that are shorter than the
|
||||||
* maximum (within the current minibatch) they will be zero padded and masked.<br>
|
* maximum (within the current minibatch) they will be zero padded and masked.<br>
|
||||||
*<br><br>
|
* <br><br>
|
||||||
* <u><b>{@link FeatureArrays} configuration:</b></u><br>
|
* <u><b>{@link FeatureArrays} configuration:</b></u><br>
|
||||||
* Determines what arrays should be included.<br>
|
* Determines what arrays should be included.<br>
|
||||||
* <b>INDICES_MASK</b>: Indices array and mask array only, no segment ID array. Returns 1 feature array, 1 feature mask array (plus labels).<br>
|
* <b>INDICES_MASK</b>: Indices array and mask array only, no segment ID array. Returns 1 feature array, 1 feature mask array (plus labels).<br>
|
||||||
|
@ -107,8 +119,11 @@ import java.util.Map;
|
||||||
public class BertIterator implements MultiDataSetIterator {
|
public class BertIterator implements MultiDataSetIterator {
|
||||||
|
|
||||||
public enum Task {UNSUPERVISED, SEQ_CLASSIFICATION}
|
public enum Task {UNSUPERVISED, SEQ_CLASSIFICATION}
|
||||||
|
|
||||||
public enum LengthHandling {FIXED_LENGTH, ANY_LENGTH, CLIP_ONLY}
|
public enum LengthHandling {FIXED_LENGTH, ANY_LENGTH, CLIP_ONLY}
|
||||||
|
|
||||||
public enum FeatureArrays {INDICES_MASK, INDICES_MASK_SEGMENTID}
|
public enum FeatureArrays {INDICES_MASK, INDICES_MASK_SEGMENTID}
|
||||||
|
|
||||||
public enum UnsupervisedLabelFormat {RANK2_IDX, RANK3_NCL, RANK3_LNC}
|
public enum UnsupervisedLabelFormat {RANK2_IDX, RANK3_NCL, RANK3_LNC}
|
||||||
|
|
||||||
protected Task task;
|
protected Task task;
|
||||||
|
@ -116,12 +131,13 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
protected int maxTokens = -1;
|
protected int maxTokens = -1;
|
||||||
protected int minibatchSize = 32;
|
protected int minibatchSize = 32;
|
||||||
protected boolean padMinibatches = false;
|
protected boolean padMinibatches = false;
|
||||||
@Getter @Setter
|
@Getter
|
||||||
|
@Setter
|
||||||
protected MultiDataSetPreProcessor preProcessor;
|
protected MultiDataSetPreProcessor preProcessor;
|
||||||
protected LabeledSentenceProvider sentenceProvider = null;
|
protected LabeledSentenceProvider sentenceProvider = null;
|
||||||
protected LengthHandling lengthHandling;
|
protected LengthHandling lengthHandling;
|
||||||
protected FeatureArrays featureArrays;
|
protected FeatureArrays featureArrays;
|
||||||
protected Map<String,Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
|
protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
|
||||||
protected BertSequenceMasker masker = null;
|
protected BertSequenceMasker masker = null;
|
||||||
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
|
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
|
||||||
protected String maskToken;
|
protected String maskToken;
|
||||||
|
@ -130,7 +146,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
|
|
||||||
protected List<String> vocabKeysAsList;
|
protected List<String> vocabKeysAsList;
|
||||||
|
|
||||||
protected BertIterator(Builder b){
|
protected BertIterator(Builder b) {
|
||||||
this.task = b.task;
|
this.task = b.task;
|
||||||
this.tokenizerFactory = b.tokenizerFactory;
|
this.tokenizerFactory = b.tokenizerFactory;
|
||||||
this.maxTokens = b.maxTokens;
|
this.maxTokens = b.maxTokens;
|
||||||
|
@ -166,10 +182,10 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
public MultiDataSet next(int num) {
|
public MultiDataSet next(int num) {
|
||||||
Preconditions.checkState(hasNext(), "No next element available");
|
Preconditions.checkState(hasNext(), "No next element available");
|
||||||
|
|
||||||
List<Pair<String,String>> list = new ArrayList<>(num);
|
List<Pair<String, String>> list = new ArrayList<>(num);
|
||||||
int count = 0;
|
int mbSize = 0;
|
||||||
if(sentenceProvider != null){
|
if (sentenceProvider != null) {
|
||||||
while(sentenceProvider.hasNext() && count++ < num) {
|
while (sentenceProvider.hasNext() && mbSize++ < num) {
|
||||||
list.add(sentenceProvider.nextSentence());
|
list.add(sentenceProvider.nextSentence());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -177,41 +193,60 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
|
throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
//Get and tokenize the sentences for this minibatch
|
|
||||||
List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(num);
|
Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list);
|
||||||
int longestSeq = -1;
|
List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
|
||||||
for(Pair<String,String> p : list){
|
int outLength = outLTokenizedSentencesPair.getLeft();
|
||||||
List<String> tokens = tokenizeSentence(p.getFirst());
|
|
||||||
tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
|
Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength);
|
||||||
longestSeq = Math.max(longestSeq, tokens.size());
|
INDArray[] featureArray = featuresAndMaskArraysPair.getFirst();
|
||||||
|
INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond();
|
||||||
|
|
||||||
|
|
||||||
|
Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength);
|
||||||
|
INDArray[] labelArray = labelsAndMaskArraysPair.getFirst();
|
||||||
|
INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond();
|
||||||
|
|
||||||
|
org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(featureArray, labelArray, featureMaskArray, labelMaskArray);
|
||||||
|
if (preProcessor != null)
|
||||||
|
preProcessor.preProcess(mds);
|
||||||
|
|
||||||
|
return mds;
|
||||||
}
|
}
|
||||||
|
|
||||||
//Determine output array length...
|
|
||||||
int outLength;
|
/**
|
||||||
switch (lengthHandling){
|
* For use during inference. Will convert a given list of sentences to features and feature masks as appropriate.
|
||||||
case FIXED_LENGTH:
|
*
|
||||||
outLength = maxTokens;
|
* @param listOnlySentences
|
||||||
break;
|
* @return Pair of INDArrays[], first element is feature arrays and the second is the masks array
|
||||||
case ANY_LENGTH:
|
*/
|
||||||
outLength = longestSeq;
|
public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) {
|
||||||
break;
|
|
||||||
case CLIP_ONLY:
|
List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
|
||||||
outLength = Math.min(maxTokens, longestSeq);
|
|
||||||
break;
|
Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel);
|
||||||
default:
|
List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
|
||||||
throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
|
int outLength = outLTokenizedSentencesPair.getLeft();
|
||||||
|
|
||||||
|
Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength);
|
||||||
|
if (preProcessor != null) {
|
||||||
|
MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null);
|
||||||
|
preProcessor.preProcess(dummyMDS);
|
||||||
|
return new Pair<INDArray[],INDArray[]>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
|
||||||
|
}
|
||||||
|
return convertMiniBatchFeatures(tokenizedSentences, outLength);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mb = tokenizedSentences.size();
|
private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokenizedSentences, int outLength) {
|
||||||
int mbPadded = padMinibatches ? minibatchSize : mb;
|
int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
|
||||||
int[][] outIdxs = new int[mbPadded][outLength];
|
int[][] outIdxs = new int[mbPadded][outLength];
|
||||||
int[][] outMask = new int[mbPadded][outLength];
|
int[][] outMask = new int[mbPadded][outLength];
|
||||||
|
for (int i = 0; i < tokenizedSentences.size(); i++) {
|
||||||
for( int i=0; i<tokenizedSentences.size(); i++ ){
|
Pair<List<String>, String> p = tokenizedSentences.get(i);
|
||||||
Pair<List<String>,String> p = tokenizedSentences.get(i);
|
|
||||||
List<String> t = p.getFirst();
|
List<String> t = p.getFirst();
|
||||||
for( int j=0; j<outLength && j<t.size(); j++ ){
|
for (int j = 0; j < outLength && j < t.size(); j++) {
|
||||||
Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encontered: token \"%s\" is not in vocabulary", t.get(j));
|
Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j));
|
||||||
int idx = vocabMap.get(t.get(j));
|
int idx = vocabMap.get(t.get(j));
|
||||||
outIdxs[i][j] = idx;
|
outIdxs[i][j] = idx;
|
||||||
outMask[i][j] = 1;
|
outMask[i][j] = 1;
|
||||||
|
@ -224,7 +259,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
INDArray outSegmentIdArr;
|
INDArray outSegmentIdArr;
|
||||||
INDArray[] f;
|
INDArray[] f;
|
||||||
INDArray[] fm;
|
INDArray[] fm;
|
||||||
if(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID){
|
if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) {
|
||||||
//For now: always segment index 0 (only single s sequence input supported)
|
//For now: always segment index 0 (only single s sequence input supported)
|
||||||
outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength);
|
outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength);
|
||||||
f = new INDArray[]{outIdxsArr, outSegmentIdArr};
|
f = new INDArray[]{outIdxsArr, outSegmentIdArr};
|
||||||
|
@ -233,17 +268,50 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
f = new INDArray[]{outIdxsArr};
|
f = new INDArray[]{outIdxsArr};
|
||||||
fm = new INDArray[]{outMaskArr};
|
fm = new INDArray[]{outMaskArr};
|
||||||
}
|
}
|
||||||
|
return new Pair<>(f, fm);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Pair<Integer, List<Pair<List<String>, String>>> tokenizeMiniBatch(List<Pair<String, String>> list) {
|
||||||
|
//Get and tokenize the sentences for this minibatch
|
||||||
|
List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(list.size());
|
||||||
|
int longestSeq = -1;
|
||||||
|
for (Pair<String, String> p : list) {
|
||||||
|
List<String> tokens = tokenizeSentence(p.getFirst());
|
||||||
|
tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
|
||||||
|
longestSeq = Math.max(longestSeq, tokens.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
//Determine output array length...
|
||||||
|
int outLength;
|
||||||
|
switch (lengthHandling) {
|
||||||
|
case FIXED_LENGTH:
|
||||||
|
outLength = maxTokens;
|
||||||
|
break;
|
||||||
|
case ANY_LENGTH:
|
||||||
|
outLength = longestSeq;
|
||||||
|
break;
|
||||||
|
case CLIP_ONLY:
|
||||||
|
outLength = Math.min(maxTokens, longestSeq);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
|
||||||
|
}
|
||||||
|
return new Pair<>(outLength, tokenizedSentences);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
|
||||||
INDArray[] l = new INDArray[1];
|
INDArray[] l = new INDArray[1];
|
||||||
INDArray[] lm;
|
INDArray[] lm;
|
||||||
if(task == Task.SEQ_CLASSIFICATION){
|
int mbSize = tokenizedSentences.size();
|
||||||
|
int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
|
||||||
|
if (task == Task.SEQ_CLASSIFICATION) {
|
||||||
//Sequence classification task: output is 2d, one-hot, shape [minibatch, numClasses]
|
//Sequence classification task: output is 2d, one-hot, shape [minibatch, numClasses]
|
||||||
int numClasses;
|
int numClasses;
|
||||||
int[] classLabels = new int[mbPadded];
|
int[] classLabels = new int[mbPadded];
|
||||||
if(sentenceProvider != null){
|
if (sentenceProvider != null) {
|
||||||
numClasses = sentenceProvider.numLabelClasses();
|
numClasses = sentenceProvider.numLabelClasses();
|
||||||
List<String> labels = sentenceProvider.allLabels();
|
List<String> labels = sentenceProvider.allLabels();
|
||||||
for(int i=0; i<mb; i++ ){
|
for (int i = 0; i < mbSize; i++) {
|
||||||
String lbl = tokenizedSentences.get(i).getRight();
|
String lbl = tokenizedSentences.get(i).getRight();
|
||||||
classLabels[i] = labels.indexOf(lbl);
|
classLabels[i] = labels.indexOf(lbl);
|
||||||
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
|
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
|
||||||
|
@ -252,21 +320,21 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
}
|
}
|
||||||
l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses);
|
l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses);
|
||||||
for( int i=0; i<mb; i++ ){
|
for (int i = 0; i < mbSize; i++) {
|
||||||
l[0].putScalar(i, classLabels[i], 1.0);
|
l[0].putScalar(i, classLabels[i], 1.0);
|
||||||
}
|
}
|
||||||
lm = null;
|
lm = null;
|
||||||
if(padMinibatches && mb != mbPadded){
|
if (padMinibatches && mbSize != mbPadded) {
|
||||||
INDArray a = Nd4j.zeros(DataType.FLOAT, mbPadded, 1);
|
INDArray a = Nd4j.zeros(DataType.FLOAT, mbPadded, 1);
|
||||||
lm = new INDArray[]{a};
|
lm = new INDArray[]{a};
|
||||||
a.get(NDArrayIndex.interval(0, mb), NDArrayIndex.all()).assign(1);
|
a.get(NDArrayIndex.interval(0, mbSize), NDArrayIndex.all()).assign(1);
|
||||||
}
|
}
|
||||||
} else if(task == Task.UNSUPERVISED){
|
} else if (task == Task.UNSUPERVISED) {
|
||||||
//Unsupervised, masked language model task
|
//Unsupervised, masked language model task
|
||||||
//Output is either 2d, or 3d depending on settings
|
//Output is either 2d, or 3d depending on settings
|
||||||
if(vocabKeysAsList == null){
|
if (vocabKeysAsList == null) {
|
||||||
String[] arr = new String[vocabMap.size()];
|
String[] arr = new String[vocabMap.size()];
|
||||||
for(Map.Entry<String,Integer> e : vocabMap.entrySet()){
|
for (Map.Entry<String, Integer> e : vocabMap.entrySet()) {
|
||||||
arr[e.getValue()] = e.getKey();
|
arr[e.getValue()] = e.getKey();
|
||||||
}
|
}
|
||||||
vocabKeysAsList = Arrays.asList(arr);
|
vocabKeysAsList = Arrays.asList(arr);
|
||||||
|
@ -276,31 +344,31 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
int vocabSize = vocabMap.size();
|
int vocabSize = vocabMap.size();
|
||||||
INDArray labelArr;
|
INDArray labelArr;
|
||||||
INDArray lMask = Nd4j.zeros(DataType.INT, mbPadded, outLength);
|
INDArray lMask = Nd4j.zeros(DataType.INT, mbPadded, outLength);
|
||||||
if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){
|
if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
|
||||||
labelArr = Nd4j.create(DataType.INT, mbPadded, outLength);
|
labelArr = Nd4j.create(DataType.INT, mbPadded, outLength);
|
||||||
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){
|
} else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
|
||||||
labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength);
|
labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength);
|
||||||
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){
|
} else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
|
||||||
labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize);
|
labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize);
|
||||||
} else {
|
} else {
|
||||||
throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
|
throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
|
||||||
}
|
}
|
||||||
|
|
||||||
for( int i=0; i<mb; i++ ){
|
for (int i = 0; i < mbSize; i++) {
|
||||||
List<String> tokens = tokenizedSentences.get(i).getFirst();
|
List<String> tokens = tokenizedSentences.get(i).getFirst();
|
||||||
Pair<List<String>,boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList);
|
Pair<List<String>, boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList);
|
||||||
List<String> maskedTokens = p.getFirst();
|
List<String> maskedTokens = p.getFirst();
|
||||||
boolean[] predictionTarget = p.getSecond();
|
boolean[] predictionTarget = p.getSecond();
|
||||||
int seqLen = Math.min(predictionTarget.length, outLength);
|
int seqLen = Math.min(predictionTarget.length, outLength);
|
||||||
for(int j=0; j<seqLen; j++ ){
|
for (int j = 0; j < seqLen; j++) {
|
||||||
if(predictionTarget[j]){
|
if (predictionTarget[j]) {
|
||||||
String oldToken = tokenizedSentences.get(i).getFirst().get(j); //This is target
|
String oldToken = tokenizedSentences.get(i).getFirst().get(j); //This is target
|
||||||
int targetTokenIdx = vocabMap.get(oldToken);
|
int targetTokenIdx = vocabMap.get(oldToken);
|
||||||
if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){
|
if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
|
||||||
labelArr.putScalar(i, j, targetTokenIdx);
|
labelArr.putScalar(i, j, targetTokenIdx);
|
||||||
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){
|
} else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
|
||||||
labelArr.putScalar(i, j, targetTokenIdx, 1.0);
|
labelArr.putScalar(i, j, targetTokenIdx, 1.0);
|
||||||
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){
|
} else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
|
||||||
labelArr.putScalar(j, i, targetTokenIdx, 1.0);
|
labelArr.putScalar(j, i, targetTokenIdx, 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -309,7 +377,8 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
//Also update previously created feature label indexes:
|
//Also update previously created feature label indexes:
|
||||||
String newToken = maskedTokens.get(j);
|
String newToken = maskedTokens.get(j);
|
||||||
int newTokenIdx = vocabMap.get(newToken);
|
int newTokenIdx = vocabMap.get(newToken);
|
||||||
outIdxsArr.putScalar(i,j,newTokenIdx);
|
//first element of features is outIdxsArr
|
||||||
|
featureArray[0].putScalar(i, j, newTokenIdx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -319,19 +388,14 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
} else {
|
} else {
|
||||||
throw new IllegalStateException("Task not yet implemented: " + task);
|
throw new IllegalStateException("Task not yet implemented: " + task);
|
||||||
}
|
}
|
||||||
|
return new Pair<>(l, lm);
|
||||||
org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(f, l, fm, lm);
|
|
||||||
if(preProcessor != null)
|
|
||||||
preProcessor.preProcess(mds);
|
|
||||||
|
|
||||||
return mds;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<String> tokenizeSentence(String sentence) {
|
private List<String> tokenizeSentence(String sentence) {
|
||||||
Tokenizer t = tokenizerFactory.create(sentence);
|
Tokenizer t = tokenizerFactory.create(sentence);
|
||||||
|
|
||||||
List<String> tokens = new ArrayList<>();
|
List<String> tokens = new ArrayList<>();
|
||||||
if(prependToken != null)
|
if (prependToken != null)
|
||||||
tokens.add(prependToken);
|
tokens.add(prependToken);
|
||||||
|
|
||||||
while (t.hasMoreTokens()) {
|
while (t.hasMoreTokens()) {
|
||||||
|
@ -341,6 +405,16 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
return tokens;
|
return tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private List<Pair<String, String>> addDummyLabel(List<String> listOnlySentences) {
|
||||||
|
List<Pair<String, String>> list = new ArrayList<>(listOnlySentences.size());
|
||||||
|
for (String s : listOnlySentences) {
|
||||||
|
list.add(new Pair<String, String>(s, null));
|
||||||
|
}
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean resetSupported() {
|
public boolean resetSupported() {
|
||||||
return true;
|
return true;
|
||||||
|
@ -353,12 +427,12 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void reset() {
|
public void reset() {
|
||||||
if(sentenceProvider != null){
|
if (sentenceProvider != null) {
|
||||||
sentenceProvider.reset();
|
sentenceProvider.reset();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Builder builder(){
|
public static Builder builder() {
|
||||||
return new Builder();
|
return new Builder();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -373,7 +447,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
protected MultiDataSetPreProcessor preProcessor;
|
protected MultiDataSetPreProcessor preProcessor;
|
||||||
protected LabeledSentenceProvider sentenceProvider = null;
|
protected LabeledSentenceProvider sentenceProvider = null;
|
||||||
protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
|
protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
|
||||||
protected Map<String,Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
|
protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
|
||||||
protected BertSequenceMasker masker = new BertMaskedLMMasker();
|
protected BertSequenceMasker masker = new BertMaskedLMMasker();
|
||||||
protected UnsupervisedLabelFormat unsupervisedLabelFormat;
|
protected UnsupervisedLabelFormat unsupervisedLabelFormat;
|
||||||
protected String maskToken;
|
protected String maskToken;
|
||||||
|
@ -382,7 +456,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
/**
|
/**
|
||||||
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
|
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
|
||||||
*/
|
*/
|
||||||
public Builder task(Task task){
|
public Builder task(Task task) {
|
||||||
this.task = task;
|
this.task = task;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
@ -392,18 +466,19 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
* For BERT, typically {@link org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory}
|
* For BERT, typically {@link org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory}
|
||||||
* is used
|
* is used
|
||||||
*/
|
*/
|
||||||
public Builder tokenizer(TokenizerFactory tokenizerFactory){
|
public Builder tokenizer(TokenizerFactory tokenizerFactory) {
|
||||||
this.tokenizerFactory = tokenizerFactory;
|
this.tokenizerFactory = tokenizerFactory;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specifies how the sequence length of the output data should be handled. See {@link BertIterator} for more details.
|
* Specifies how the sequence length of the output data should be handled. See {@link BertIterator} for more details.
|
||||||
|
*
|
||||||
* @param lengthHandling Length handling
|
* @param lengthHandling Length handling
|
||||||
* @param maxLength Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH}
|
* @param maxLength Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH}
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength){
|
public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength) {
|
||||||
this.lengthHandling = lengthHandling;
|
this.lengthHandling = lengthHandling;
|
||||||
this.maxTokens = maxLength;
|
this.maxTokens = maxLength;
|
||||||
return this;
|
return this;
|
||||||
|
@ -412,9 +487,10 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
/**
|
/**
|
||||||
* Minibatch size to use (number of examples to train on for each iteration)
|
* Minibatch size to use (number of examples to train on for each iteration)
|
||||||
* See also: {@link #padMinibatches}
|
* See also: {@link #padMinibatches}
|
||||||
|
*
|
||||||
* @param minibatchSize Minibatch size
|
* @param minibatchSize Minibatch size
|
||||||
*/
|
*/
|
||||||
public Builder minibatchSize(int minibatchSize){
|
public Builder minibatchSize(int minibatchSize) {
|
||||||
this.minibatchSize = minibatchSize;
|
this.minibatchSize = minibatchSize;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
@ -429,7 +505,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
* Both options should result in exactly the same model. However, some BERT implementations may require exactly an
|
* Both options should result in exactly the same model. However, some BERT implementations may require exactly an
|
||||||
* exact number of examples in all minibatches to function.
|
* exact number of examples in all minibatches to function.
|
||||||
*/
|
*/
|
||||||
public Builder padMinibatches(boolean padMinibatches){
|
public Builder padMinibatches(boolean padMinibatches) {
|
||||||
this.padMinibatches = padMinibatches;
|
this.padMinibatches = padMinibatches;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
@ -437,7 +513,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
/**
|
/**
|
||||||
* Set the preprocessor to be used on the MultiDataSets before returning them. Default: none (null)
|
* Set the preprocessor to be used on the MultiDataSets before returning them. Default: none (null)
|
||||||
*/
|
*/
|
||||||
public Builder preProcessor(MultiDataSetPreProcessor preProcessor){
|
public Builder preProcessor(MultiDataSetPreProcessor preProcessor) {
|
||||||
this.preProcessor = preProcessor;
|
this.preProcessor = preProcessor;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
@ -446,7 +522,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
* Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised
|
* Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised
|
||||||
* use case, the labels will be ignored.
|
* use case, the labels will be ignored.
|
||||||
*/
|
*/
|
||||||
public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider){
|
public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) {
|
||||||
this.sentenceProvider = sentenceProvider;
|
this.sentenceProvider = sentenceProvider;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
@ -454,7 +530,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
/**
|
/**
|
||||||
* Specify what arrays should be returned. See {@link BertIterator} for more details.
|
* Specify what arrays should be returned. See {@link BertIterator} for more details.
|
||||||
*/
|
*/
|
||||||
public Builder featureArrays(FeatureArrays featureArrays){
|
public Builder featureArrays(FeatureArrays featureArrays) {
|
||||||
this.featureArrays = featureArrays;
|
this.featureArrays = featureArrays;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
@ -465,7 +541,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
* If using {@link BertWordPieceTokenizerFactory},
|
* If using {@link BertWordPieceTokenizerFactory},
|
||||||
* this can be obtained using {@link BertWordPieceTokenizerFactory#getVocab()}
|
* this can be obtained using {@link BertWordPieceTokenizerFactory#getVocab()}
|
||||||
*/
|
*/
|
||||||
public Builder vocabMap(Map<String,Integer> vocabMap){
|
public Builder vocabMap(Map<String, Integer> vocabMap) {
|
||||||
this.vocabMap = vocabMap;
|
this.vocabMap = vocabMap;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
@ -475,7 +551,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
* masked language model. This can be used to customize how the masking is performed.<br>
|
* masked language model. This can be used to customize how the masking is performed.<br>
|
||||||
* Default: {@link BertMaskedLMMasker}
|
* Default: {@link BertMaskedLMMasker}
|
||||||
*/
|
*/
|
||||||
public Builder masker(BertSequenceMasker masker){
|
public Builder masker(BertSequenceMasker masker) {
|
||||||
this.masker = masker;
|
this.masker = masker;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
@ -485,7 +561,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
* masked language model. Used to specify the format that the labels should be returned in.
|
* masked language model. Used to specify the format that the labels should be returned in.
|
||||||
* See {@link BertIterator} for more details.
|
* See {@link BertIterator} for more details.
|
||||||
*/
|
*/
|
||||||
public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat){
|
public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat) {
|
||||||
this.unsupervisedLabelFormat = labelFormat;
|
this.unsupervisedLabelFormat = labelFormat;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
@ -497,7 +573,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
* the exact behaviour will depend on what masker is used.<br>
|
* the exact behaviour will depend on what masker is used.<br>
|
||||||
* Note that this must be in the vocabulary map set in {@link #vocabMap}
|
* Note that this must be in the vocabulary map set in {@link #vocabMap}
|
||||||
*/
|
*/
|
||||||
public Builder maskToken(String maskToken){
|
public Builder maskToken(String maskToken) {
|
||||||
this.maskToken = maskToken;
|
this.maskToken = maskToken;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
@ -510,12 +586,12 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
*
|
*
|
||||||
* @param prependToken The token to start each sequence with (null: no token will be prepended)
|
* @param prependToken The token to start each sequence with (null: no token will be prepended)
|
||||||
*/
|
*/
|
||||||
public Builder prependToken(String prependToken){
|
public Builder prependToken(String prependToken) {
|
||||||
this.prependToken = prependToken;
|
this.prependToken = prependToken;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public BertIterator build(){
|
public BertIterator build() {
|
||||||
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
|
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
|
||||||
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
|
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
|
||||||
Preconditions.checkState(vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map<String,Integer>) to set");
|
Preconditions.checkState(vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map<String,Integer>) to set");
|
||||||
|
|
|
@ -26,7 +26,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.resources.Resources;
|
import org.nd4j.resources.Resources;
|
||||||
|
|
||||||
|
@ -34,10 +33,7 @@ import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.charset.Charset;
|
import java.nio.charset.Charset;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.Arrays;
|
import java.util.*;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Random;
|
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
@ -54,6 +50,9 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
|
|
||||||
String toTokenize1 = "I saw a girl with a telescope.";
|
String toTokenize1 = "I saw a girl with a telescope.";
|
||||||
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
||||||
|
List<String> forInference = new ArrayList<>();
|
||||||
|
forInference.add(toTokenize1);
|
||||||
|
forInference.add(toTokenize2);
|
||||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
||||||
|
|
||||||
BertIterator b = BertIterator.builder()
|
BertIterator b = BertIterator.builder()
|
||||||
|
@ -100,12 +99,15 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
|
|
||||||
assertEquals(expF, mds.getFeatures(0));
|
assertEquals(expF, mds.getFeatures(0));
|
||||||
assertEquals(expM, mds.getFeaturesMaskArray(0));
|
assertEquals(expM, mds.getFeaturesMaskArray(0));
|
||||||
|
assertEquals(expF,b.featurizeSentences(forInference).getFirst()[0]);
|
||||||
|
assertEquals(expM,b.featurizeSentences(forInference).getSecond()[0]);
|
||||||
|
|
||||||
assertFalse(b.hasNext());
|
assertFalse(b.hasNext());
|
||||||
b.reset();
|
b.reset();
|
||||||
assertTrue(b.hasNext());
|
assertTrue(b.hasNext());
|
||||||
MultiDataSet mds2 = b.next();
|
MultiDataSet mds2 = b.next();
|
||||||
|
|
||||||
|
forInference.set(0,toTokenize2);
|
||||||
//Same thing, but with segment ID also
|
//Same thing, but with segment ID also
|
||||||
b = BertIterator.builder()
|
b = BertIterator.builder()
|
||||||
.tokenizer(t)
|
.tokenizer(t)
|
||||||
|
@ -118,9 +120,11 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
.build();
|
.build();
|
||||||
mds = b.next();
|
mds = b.next();
|
||||||
assertEquals(2, mds.getFeatures().length);
|
assertEquals(2, mds.getFeatures().length);
|
||||||
|
assertEquals(2,b.featurizeSentences(forInference).getFirst().length);
|
||||||
//Segment ID should be all 0s for single segment task
|
//Segment ID should be all 0s for single segment task
|
||||||
INDArray segmentId = expM.like();
|
INDArray segmentId = expM.like();
|
||||||
assertEquals(segmentId, mds.getFeatures(1));
|
assertEquals(segmentId, mds.getFeatures(1));
|
||||||
|
assertEquals(segmentId,b.featurizeSentences(forInference).getFirst()[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 20000L)
|
@Test(timeout = 20000L)
|
||||||
|
@ -157,6 +161,9 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
public void testLengthHandling() throws Exception {
|
public void testLengthHandling() throws Exception {
|
||||||
String toTokenize1 = "I saw a girl with a telescope.";
|
String toTokenize1 = "I saw a girl with a telescope.";
|
||||||
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
||||||
|
List<String> forInference = new ArrayList<>();
|
||||||
|
forInference.add(toTokenize1);
|
||||||
|
forInference.add(toTokenize2);
|
||||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
||||||
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
|
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
|
||||||
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
|
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
|
||||||
|
@ -205,6 +212,8 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
|
assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
|
||||||
assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeatures(0));
|
assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeatures(0));
|
||||||
assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeaturesMaskArray(0));
|
assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeaturesMaskArray(0));
|
||||||
|
assertEquals(mds.getFeatures(0),b.featurizeSentences(forInference).getFirst()[0]);
|
||||||
|
assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]);
|
||||||
|
|
||||||
//Clip only: clip to maximum, but don't pad if less
|
//Clip only: clip to maximum, but don't pad if less
|
||||||
b = BertIterator.builder()
|
b = BertIterator.builder()
|
||||||
|
@ -227,6 +236,9 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||||
String toTokenize1 = "I saw a girl with a telescope.";
|
String toTokenize1 = "I saw a girl with a telescope.";
|
||||||
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
||||||
|
List<String> forInference = new ArrayList<>();
|
||||||
|
forInference.add(toTokenize1);
|
||||||
|
forInference.add(toTokenize2);
|
||||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
||||||
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
|
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
|
||||||
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
|
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
|
||||||
|
@ -288,6 +300,9 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
assertEquals(expM, mds.getFeaturesMaskArray(0));
|
assertEquals(expM, mds.getFeaturesMaskArray(0));
|
||||||
assertEquals(expL, mds.getLabels(0));
|
assertEquals(expL, mds.getLabels(0));
|
||||||
assertEquals(expLM, mds.getLabelsMaskArray(0));
|
assertEquals(expLM, mds.getLabelsMaskArray(0));
|
||||||
|
|
||||||
|
assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
|
||||||
|
assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class TestSentenceProvider implements LabeledSentenceProvider {
|
private static class TestSentenceProvider implements LabeledSentenceProvider {
|
||||||
|
|
Loading…
Reference in New Issue