* Convenience method for inference with BERT iterator Signed-off-by: eraly <susan.eraly@gmail.com> * Included preprocessing Signed-off-by: eraly <susan.eraly@gmail.com> * Copyright + example Signed-off-by: eraly <susan.eraly@gmail.com>
This commit is contained in:
		
							parent
							
								
									7a90a31cfb
								
							
						
					
					
						commit
						823bd0ff88
					
				| @ -1,5 +1,6 @@ | |||||||
| /******************************************************************************* | /******************************************************************************* | ||||||
|  * Copyright (c) 2015-2019 Skymind, Inc. |  * Copyright (c) 2015-2019 Skymind, Inc. | ||||||
|  |  * Copyright (c) 2019 Konduit K.K. | ||||||
|  * |  * | ||||||
|  * This program and the accompanying materials are made available under the |  * This program and the accompanying materials are made available under the | ||||||
|  * terms of the Apache License, Version 2.0 which is available at |  * terms of the Apache License, Version 2.0 which is available at | ||||||
| @ -79,6 +80,17 @@ import java.util.Map; | |||||||
|  *              .build(); |  *              .build(); | ||||||
|  * } |  * } | ||||||
|  * </pre> |  * </pre> | ||||||
|  |  * <br> | ||||||
|  |  * <b>Example of using an instantiated iterator for inference:</b><br> | ||||||
|  |  * <pre> | ||||||
|  |  * {@code | ||||||
|  |  *          BertIterator b; | ||||||
|  |  *          List<String> forInference; | ||||||
|  |  *          Pair<INDArray[],INDArray[]> featuresAndMask = b.featurizeSentences(forInference); | ||||||
|  |  *          INDArray[] features = featuresAndMask.getFirst(); | ||||||
|  |  *          INDArray[] featureMasks = featuresAndMask.getSecond(); | ||||||
|  |  * } | ||||||
|  |  * </pre> | ||||||
|  * This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br> |  * This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br> | ||||||
|  * <br> |  * <br> | ||||||
|  * <u><b>{@link LengthHandling} configuration:</b></u><br> |  * <u><b>{@link LengthHandling} configuration:</b></u><br> | ||||||
| @ -89,7 +101,7 @@ import java.util.Map; | |||||||
|  * <b>CLIP_ONLY</b>: For any sequences longer than the specified maximum, clip them. If the maximum sequence length in |  * <b>CLIP_ONLY</b>: For any sequences longer than the specified maximum, clip them. If the maximum sequence length in | ||||||
|  * a minibatch is shorter than the specified maximum, no padding will occur. For sequences that are shorter than the |  * a minibatch is shorter than the specified maximum, no padding will occur. For sequences that are shorter than the | ||||||
|  * maximum (within the current minibatch) they will be zero padded and masked.<br> |  * maximum (within the current minibatch) they will be zero padded and masked.<br> | ||||||
|  *<br><br> |  * <br><br> | ||||||
|  * <u><b>{@link FeatureArrays} configuration:</b></u><br> |  * <u><b>{@link FeatureArrays} configuration:</b></u><br> | ||||||
|  * Determines what arrays should be included.<br> |  * Determines what arrays should be included.<br> | ||||||
|  * <b>INDICES_MASK</b>: Indices array and mask array only, no segment ID array. Returns 1 feature array, 1 feature mask array (plus labels).<br> |  * <b>INDICES_MASK</b>: Indices array and mask array only, no segment ID array. Returns 1 feature array, 1 feature mask array (plus labels).<br> | ||||||
| @ -107,8 +119,11 @@ import java.util.Map; | |||||||
| public class BertIterator implements MultiDataSetIterator { | public class BertIterator implements MultiDataSetIterator { | ||||||
| 
 | 
 | ||||||
|     public enum Task {UNSUPERVISED, SEQ_CLASSIFICATION} |     public enum Task {UNSUPERVISED, SEQ_CLASSIFICATION} | ||||||
|  | 
 | ||||||
|     public enum LengthHandling {FIXED_LENGTH, ANY_LENGTH, CLIP_ONLY} |     public enum LengthHandling {FIXED_LENGTH, ANY_LENGTH, CLIP_ONLY} | ||||||
|  | 
 | ||||||
|     public enum FeatureArrays {INDICES_MASK, INDICES_MASK_SEGMENTID} |     public enum FeatureArrays {INDICES_MASK, INDICES_MASK_SEGMENTID} | ||||||
|  | 
 | ||||||
|     public enum UnsupervisedLabelFormat {RANK2_IDX, RANK3_NCL, RANK3_LNC} |     public enum UnsupervisedLabelFormat {RANK2_IDX, RANK3_NCL, RANK3_LNC} | ||||||
| 
 | 
 | ||||||
|     protected Task task; |     protected Task task; | ||||||
| @ -116,12 +131,13 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|     protected int maxTokens = -1; |     protected int maxTokens = -1; | ||||||
|     protected int minibatchSize = 32; |     protected int minibatchSize = 32; | ||||||
|     protected boolean padMinibatches = false; |     protected boolean padMinibatches = false; | ||||||
|     @Getter @Setter |     @Getter | ||||||
|  |     @Setter | ||||||
|     protected MultiDataSetPreProcessor preProcessor; |     protected MultiDataSetPreProcessor preProcessor; | ||||||
|     protected LabeledSentenceProvider sentenceProvider = null; |     protected LabeledSentenceProvider sentenceProvider = null; | ||||||
|     protected LengthHandling lengthHandling; |     protected LengthHandling lengthHandling; | ||||||
|     protected FeatureArrays featureArrays; |     protected FeatureArrays featureArrays; | ||||||
|     protected Map<String,Integer> vocabMap;   //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects? |     protected Map<String, Integer> vocabMap;   //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects? | ||||||
|     protected BertSequenceMasker masker = null; |     protected BertSequenceMasker masker = null; | ||||||
|     protected UnsupervisedLabelFormat unsupervisedLabelFormat = null; |     protected UnsupervisedLabelFormat unsupervisedLabelFormat = null; | ||||||
|     protected String maskToken; |     protected String maskToken; | ||||||
| @ -130,7 +146,7 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
| 
 | 
 | ||||||
|     protected List<String> vocabKeysAsList; |     protected List<String> vocabKeysAsList; | ||||||
| 
 | 
 | ||||||
|     protected BertIterator(Builder b){ |     protected BertIterator(Builder b) { | ||||||
|         this.task = b.task; |         this.task = b.task; | ||||||
|         this.tokenizerFactory = b.tokenizerFactory; |         this.tokenizerFactory = b.tokenizerFactory; | ||||||
|         this.maxTokens = b.maxTokens; |         this.maxTokens = b.maxTokens; | ||||||
| @ -166,10 +182,10 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|     public MultiDataSet next(int num) { |     public MultiDataSet next(int num) { | ||||||
|         Preconditions.checkState(hasNext(), "No next element available"); |         Preconditions.checkState(hasNext(), "No next element available"); | ||||||
| 
 | 
 | ||||||
|         List<Pair<String,String>> list = new ArrayList<>(num); |         List<Pair<String, String>> list = new ArrayList<>(num); | ||||||
|         int count = 0; |         int mbSize = 0; | ||||||
|         if(sentenceProvider != null){ |         if (sentenceProvider != null) { | ||||||
|             while(sentenceProvider.hasNext() && count++ < num) { |             while (sentenceProvider.hasNext() && mbSize++ < num) { | ||||||
|                 list.add(sentenceProvider.nextSentence()); |                 list.add(sentenceProvider.nextSentence()); | ||||||
|             } |             } | ||||||
|         } else { |         } else { | ||||||
| @ -177,41 +193,60 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|             throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented"); |             throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented"); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         //Get and tokenize the sentences for this minibatch |  | ||||||
|         List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(num); |  | ||||||
|         int longestSeq = -1; |  | ||||||
|         for(Pair<String,String> p : list){ |  | ||||||
|             List<String> tokens = tokenizeSentence(p.getFirst()); |  | ||||||
|             tokenizedSentences.add(new Pair<>(tokens, p.getSecond())); |  | ||||||
|             longestSeq = Math.max(longestSeq, tokens.size()); |  | ||||||
|         } |  | ||||||
| 
 | 
 | ||||||
|         //Determine output array length... |         Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list); | ||||||
|         int outLength; |         List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight(); | ||||||
|         switch (lengthHandling){ |         int outLength = outLTokenizedSentencesPair.getLeft(); | ||||||
|             case FIXED_LENGTH: |  | ||||||
|                 outLength = maxTokens; |  | ||||||
|                 break; |  | ||||||
|             case ANY_LENGTH: |  | ||||||
|                 outLength = longestSeq; |  | ||||||
|                 break; |  | ||||||
|             case CLIP_ONLY: |  | ||||||
|                 outLength = Math.min(maxTokens, longestSeq); |  | ||||||
|                 break; |  | ||||||
|             default: |  | ||||||
|                 throw new RuntimeException("Not implemented length handling mode: " + lengthHandling); |  | ||||||
|         } |  | ||||||
| 
 | 
 | ||||||
|         int mb = tokenizedSentences.size(); |         Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength); | ||||||
|         int mbPadded = padMinibatches ? minibatchSize : mb; |         INDArray[] featureArray = featuresAndMaskArraysPair.getFirst(); | ||||||
|  |         INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond(); | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |         Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength); | ||||||
|  |         INDArray[] labelArray = labelsAndMaskArraysPair.getFirst(); | ||||||
|  |         INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond(); | ||||||
|  | 
 | ||||||
|  |         org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(featureArray, labelArray, featureMaskArray, labelMaskArray); | ||||||
|  |         if (preProcessor != null) | ||||||
|  |             preProcessor.preProcess(mds); | ||||||
|  | 
 | ||||||
|  |         return mds; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     /** | ||||||
|  |      * For use during inference. Will convert a given list of sentences to features and feature masks as appropriate. | ||||||
|  |      * | ||||||
|  |      * @param listOnlySentences Sentences to featurize; no labels are required | ||||||
|  |      * @return Pair of INDArray[]: the first element is the feature arrays and the second is the feature mask arrays | ||||||
|  |      */ | ||||||
|  |     public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) { | ||||||
|  | 
 | ||||||
|  |         List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences); | ||||||
|  | 
 | ||||||
|  |         Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel); | ||||||
|  |         List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight(); | ||||||
|  |         int outLength = outLTokenizedSentencesPair.getLeft(); | ||||||
|  | 
 | ||||||
|  |         Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength); | ||||||
|  |         if (preProcessor != null) { | ||||||
|  |             MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null); | ||||||
|  |             preProcessor.preProcess(dummyMDS); | ||||||
|  |             return new Pair<INDArray[],INDArray[]>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays()); | ||||||
|  |         } | ||||||
|  |         return convertMiniBatchFeatures(tokenizedSentences, outLength); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokenizedSentences, int outLength) { | ||||||
|  |         int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size(); | ||||||
|         int[][] outIdxs = new int[mbPadded][outLength]; |         int[][] outIdxs = new int[mbPadded][outLength]; | ||||||
|         int[][] outMask = new int[mbPadded][outLength]; |         int[][] outMask = new int[mbPadded][outLength]; | ||||||
| 
 |         for (int i = 0; i < tokenizedSentences.size(); i++) { | ||||||
|         for( int i=0; i<tokenizedSentences.size(); i++ ){ |             Pair<List<String>, String> p = tokenizedSentences.get(i); | ||||||
|             Pair<List<String>,String> p = tokenizedSentences.get(i); |  | ||||||
|             List<String> t = p.getFirst(); |             List<String> t = p.getFirst(); | ||||||
|             for( int j=0; j<outLength && j<t.size(); j++ ){ |             for (int j = 0; j < outLength && j < t.size(); j++) { | ||||||
|                 Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encontered: token \"%s\" is not in vocabulary", t.get(j)); |                 Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j)); | ||||||
|                 int idx = vocabMap.get(t.get(j)); |                 int idx = vocabMap.get(t.get(j)); | ||||||
|                 outIdxs[i][j] = idx; |                 outIdxs[i][j] = idx; | ||||||
|                 outMask[i][j] = 1; |                 outMask[i][j] = 1; | ||||||
| @ -224,7 +259,7 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|         INDArray outSegmentIdArr; |         INDArray outSegmentIdArr; | ||||||
|         INDArray[] f; |         INDArray[] f; | ||||||
|         INDArray[] fm; |         INDArray[] fm; | ||||||
|         if(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID){ |         if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) { | ||||||
|             //For now: always segment index 0 (only single sequence input supported) |             //For now: always segment index 0 (only single sequence input supported) | ||||||
|             outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength); |             outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength); | ||||||
|             f = new INDArray[]{outIdxsArr, outSegmentIdArr}; |             f = new INDArray[]{outIdxsArr, outSegmentIdArr}; | ||||||
| @ -233,17 +268,50 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|             f = new INDArray[]{outIdxsArr}; |             f = new INDArray[]{outIdxsArr}; | ||||||
|             fm = new INDArray[]{outMaskArr}; |             fm = new INDArray[]{outMaskArr}; | ||||||
|         } |         } | ||||||
|  |         return new Pair<>(f, fm); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|  |     private Pair<Integer, List<Pair<List<String>, String>>> tokenizeMiniBatch(List<Pair<String, String>> list) { | ||||||
|  |         //Get and tokenize the sentences for this minibatch | ||||||
|  |         List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(list.size()); | ||||||
|  |         int longestSeq = -1; | ||||||
|  |         for (Pair<String, String> p : list) { | ||||||
|  |             List<String> tokens = tokenizeSentence(p.getFirst()); | ||||||
|  |             tokenizedSentences.add(new Pair<>(tokens, p.getSecond())); | ||||||
|  |             longestSeq = Math.max(longestSeq, tokens.size()); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         //Determine output array length... | ||||||
|  |         int outLength; | ||||||
|  |         switch (lengthHandling) { | ||||||
|  |             case FIXED_LENGTH: | ||||||
|  |                 outLength = maxTokens; | ||||||
|  |                 break; | ||||||
|  |             case ANY_LENGTH: | ||||||
|  |                 outLength = longestSeq; | ||||||
|  |                 break; | ||||||
|  |             case CLIP_ONLY: | ||||||
|  |                 outLength = Math.min(maxTokens, longestSeq); | ||||||
|  |                 break; | ||||||
|  |             default: | ||||||
|  |                 throw new RuntimeException("Not implemented length handling mode: " + lengthHandling); | ||||||
|  |         } | ||||||
|  |         return new Pair<>(outLength, tokenizedSentences); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) { | ||||||
|         INDArray[] l = new INDArray[1]; |         INDArray[] l = new INDArray[1]; | ||||||
|         INDArray[] lm; |         INDArray[] lm; | ||||||
|         if(task == Task.SEQ_CLASSIFICATION){ |         int mbSize = tokenizedSentences.size(); | ||||||
|  |         int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size(); | ||||||
|  |         if (task == Task.SEQ_CLASSIFICATION) { | ||||||
|             //Sequence classification task: output is 2d, one-hot, shape [minibatch, numClasses] |             //Sequence classification task: output is 2d, one-hot, shape [minibatch, numClasses] | ||||||
|             int numClasses; |             int numClasses; | ||||||
|             int[] classLabels = new int[mbPadded]; |             int[] classLabels = new int[mbPadded]; | ||||||
|             if(sentenceProvider != null){ |             if (sentenceProvider != null) { | ||||||
|                 numClasses = sentenceProvider.numLabelClasses(); |                 numClasses = sentenceProvider.numLabelClasses(); | ||||||
|                 List<String> labels = sentenceProvider.allLabels(); |                 List<String> labels = sentenceProvider.allLabels(); | ||||||
|                 for(int i=0; i<mb; i++ ){ |                 for (int i = 0; i < mbSize; i++) { | ||||||
|                     String lbl = tokenizedSentences.get(i).getRight(); |                     String lbl = tokenizedSentences.get(i).getRight(); | ||||||
|                     classLabels[i] = labels.indexOf(lbl); |                     classLabels[i] = labels.indexOf(lbl); | ||||||
|                     Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl); |                     Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl); | ||||||
| @ -252,21 +320,21 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|                 throw new RuntimeException(); |                 throw new RuntimeException(); | ||||||
|             } |             } | ||||||
|             l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses); |             l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses); | ||||||
|             for( int i=0; i<mb; i++ ){ |             for (int i = 0; i < mbSize; i++) { | ||||||
|                 l[0].putScalar(i, classLabels[i], 1.0); |                 l[0].putScalar(i, classLabels[i], 1.0); | ||||||
|             } |             } | ||||||
|             lm = null; |             lm = null; | ||||||
|             if(padMinibatches && mb != mbPadded){ |             if (padMinibatches && mbSize != mbPadded) { | ||||||
|                 INDArray a = Nd4j.zeros(DataType.FLOAT, mbPadded, 1); |                 INDArray a = Nd4j.zeros(DataType.FLOAT, mbPadded, 1); | ||||||
|                 lm = new INDArray[]{a}; |                 lm = new INDArray[]{a}; | ||||||
|                 a.get(NDArrayIndex.interval(0, mb), NDArrayIndex.all()).assign(1); |                 a.get(NDArrayIndex.interval(0, mbSize), NDArrayIndex.all()).assign(1); | ||||||
|             } |             } | ||||||
|         } else if(task == Task.UNSUPERVISED){ |         } else if (task == Task.UNSUPERVISED) { | ||||||
|             //Unsupervised, masked language model task |             //Unsupervised, masked language model task | ||||||
|             //Output is either 2d, or 3d depending on settings |             //Output is either 2d, or 3d depending on settings | ||||||
|             if(vocabKeysAsList == null){ |             if (vocabKeysAsList == null) { | ||||||
|                 String[] arr = new String[vocabMap.size()]; |                 String[] arr = new String[vocabMap.size()]; | ||||||
|                 for(Map.Entry<String,Integer> e : vocabMap.entrySet()){ |                 for (Map.Entry<String, Integer> e : vocabMap.entrySet()) { | ||||||
|                     arr[e.getValue()] = e.getKey(); |                     arr[e.getValue()] = e.getKey(); | ||||||
|                 } |                 } | ||||||
|                 vocabKeysAsList = Arrays.asList(arr); |                 vocabKeysAsList = Arrays.asList(arr); | ||||||
| @ -276,31 +344,31 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|             int vocabSize = vocabMap.size(); |             int vocabSize = vocabMap.size(); | ||||||
|             INDArray labelArr; |             INDArray labelArr; | ||||||
|             INDArray lMask = Nd4j.zeros(DataType.INT, mbPadded, outLength); |             INDArray lMask = Nd4j.zeros(DataType.INT, mbPadded, outLength); | ||||||
|             if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){ |             if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) { | ||||||
|                 labelArr = Nd4j.create(DataType.INT, mbPadded, outLength); |                 labelArr = Nd4j.create(DataType.INT, mbPadded, outLength); | ||||||
|             } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){ |             } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) { | ||||||
|                 labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength); |                 labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength); | ||||||
|             } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){ |             } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) { | ||||||
|                 labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize); |                 labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize); | ||||||
|             } else { |             } else { | ||||||
|                 throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat); |                 throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             for( int i=0; i<mb; i++ ){ |             for (int i = 0; i < mbSize; i++) { | ||||||
|                 List<String> tokens = tokenizedSentences.get(i).getFirst(); |                 List<String> tokens = tokenizedSentences.get(i).getFirst(); | ||||||
|                 Pair<List<String>,boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList); |                 Pair<List<String>, boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList); | ||||||
|                 List<String> maskedTokens = p.getFirst(); |                 List<String> maskedTokens = p.getFirst(); | ||||||
|                 boolean[] predictionTarget = p.getSecond(); |                 boolean[] predictionTarget = p.getSecond(); | ||||||
|                 int seqLen = Math.min(predictionTarget.length, outLength); |                 int seqLen = Math.min(predictionTarget.length, outLength); | ||||||
|                 for(int j=0; j<seqLen; j++ ){ |                 for (int j = 0; j < seqLen; j++) { | ||||||
|                     if(predictionTarget[j]){ |                     if (predictionTarget[j]) { | ||||||
|                         String oldToken = tokenizedSentences.get(i).getFirst().get(j);  //This is target |                         String oldToken = tokenizedSentences.get(i).getFirst().get(j);  //This is target | ||||||
|                         int targetTokenIdx = vocabMap.get(oldToken); |                         int targetTokenIdx = vocabMap.get(oldToken); | ||||||
|                         if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){ |                         if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) { | ||||||
|                             labelArr.putScalar(i, j, targetTokenIdx); |                             labelArr.putScalar(i, j, targetTokenIdx); | ||||||
|                         } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){ |                         } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) { | ||||||
|                             labelArr.putScalar(i, j, targetTokenIdx, 1.0); |                             labelArr.putScalar(i, j, targetTokenIdx, 1.0); | ||||||
|                         } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){ |                         } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) { | ||||||
|                             labelArr.putScalar(j, i, targetTokenIdx, 1.0); |                             labelArr.putScalar(j, i, targetTokenIdx, 1.0); | ||||||
|                         } |                         } | ||||||
| 
 | 
 | ||||||
| @ -309,7 +377,8 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|                         //Also update previously created feature label indexes: |                         //Also update previously created feature label indexes: | ||||||
|                         String newToken = maskedTokens.get(j); |                         String newToken = maskedTokens.get(j); | ||||||
|                         int newTokenIdx = vocabMap.get(newToken); |                         int newTokenIdx = vocabMap.get(newToken); | ||||||
|                         outIdxsArr.putScalar(i,j,newTokenIdx); |                         //first element of features is outIdxsArr | ||||||
|  |                         featureArray[0].putScalar(i, j, newTokenIdx); | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @ -319,19 +388,14 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|         } else { |         } else { | ||||||
|             throw new IllegalStateException("Task not yet implemented: " + task); |             throw new IllegalStateException("Task not yet implemented: " + task); | ||||||
|         } |         } | ||||||
| 
 |         return new Pair<>(l, lm); | ||||||
|         org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(f, l, fm, lm); |  | ||||||
|         if(preProcessor != null) |  | ||||||
|             preProcessor.preProcess(mds); |  | ||||||
| 
 |  | ||||||
|         return mds; |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     private List<String> tokenizeSentence(String sentence) { |     private List<String> tokenizeSentence(String sentence) { | ||||||
|         Tokenizer t = tokenizerFactory.create(sentence); |         Tokenizer t = tokenizerFactory.create(sentence); | ||||||
| 
 | 
 | ||||||
|         List<String> tokens = new ArrayList<>(); |         List<String> tokens = new ArrayList<>(); | ||||||
|         if(prependToken != null) |         if (prependToken != null) | ||||||
|             tokens.add(prependToken); |             tokens.add(prependToken); | ||||||
| 
 | 
 | ||||||
|         while (t.hasMoreTokens()) { |         while (t.hasMoreTokens()) { | ||||||
| @ -341,6 +405,16 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|         return tokens; |         return tokens; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  |     private List<Pair<String, String>> addDummyLabel(List<String> listOnlySentences) { | ||||||
|  |         List<Pair<String, String>> list = new ArrayList<>(listOnlySentences.size()); | ||||||
|  |         for (String s : listOnlySentences) { | ||||||
|  |             list.add(new Pair<String, String>(s, null)); | ||||||
|  |         } | ||||||
|  |         return list; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|     @Override |     @Override | ||||||
|     public boolean resetSupported() { |     public boolean resetSupported() { | ||||||
|         return true; |         return true; | ||||||
| @ -353,12 +427,12 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public void reset() { |     public void reset() { | ||||||
|         if(sentenceProvider != null){ |         if (sentenceProvider != null) { | ||||||
|             sentenceProvider.reset(); |             sentenceProvider.reset(); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public static Builder builder(){ |     public static Builder builder() { | ||||||
|         return new Builder(); |         return new Builder(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -373,7 +447,7 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|         protected MultiDataSetPreProcessor preProcessor; |         protected MultiDataSetPreProcessor preProcessor; | ||||||
|         protected LabeledSentenceProvider sentenceProvider = null; |         protected LabeledSentenceProvider sentenceProvider = null; | ||||||
|         protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID; |         protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID; | ||||||
|         protected Map<String,Integer> vocabMap;   //TODO maybe use Eclipse ObjectIntHashMap for fewer objects? |         protected Map<String, Integer> vocabMap;   //TODO maybe use Eclipse ObjectIntHashMap for fewer objects? | ||||||
|         protected BertSequenceMasker masker = new BertMaskedLMMasker(); |         protected BertSequenceMasker masker = new BertMaskedLMMasker(); | ||||||
|         protected UnsupervisedLabelFormat unsupervisedLabelFormat; |         protected UnsupervisedLabelFormat unsupervisedLabelFormat; | ||||||
|         protected String maskToken; |         protected String maskToken; | ||||||
| @ -382,7 +456,7 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|         /** |         /** | ||||||
|          * Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details. |          * Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details. | ||||||
|          */ |          */ | ||||||
|         public Builder task(Task task){ |         public Builder task(Task task) { | ||||||
|             this.task = task; |             this.task = task; | ||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
| @ -392,18 +466,19 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|          * For BERT, typically {@link org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory} |          * For BERT, typically {@link org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory} | ||||||
|          * is used |          * is used | ||||||
|          */ |          */ | ||||||
|         public Builder tokenizer(TokenizerFactory tokenizerFactory){ |         public Builder tokenizer(TokenizerFactory tokenizerFactory) { | ||||||
|             this.tokenizerFactory = tokenizerFactory; |             this.tokenizerFactory = tokenizerFactory; | ||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         /** |         /** | ||||||
|          * Specifies how the sequence length of the output data should be handled. See {@link BertIterator} for more details. |          * Specifies how the sequence length of the output data should be handled. See {@link BertIterator} for more details. | ||||||
|          * @param lengthHandling    Length handling |          * | ||||||
|          * @param maxLength         Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH} |          * @param lengthHandling Length handling | ||||||
|  |          * @param maxLength      Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH} | ||||||
|          * @return This builder, to allow method chaining |          * @return This builder, to allow method chaining | ||||||
|          */ |          */ | ||||||
|         public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength){ |         public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength) { | ||||||
|             this.lengthHandling = lengthHandling; |             this.lengthHandling = lengthHandling; | ||||||
|             this.maxTokens = maxLength; |             this.maxTokens = maxLength; | ||||||
|             return this; |             return this; | ||||||
| @ -412,9 +487,10 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|         /** |         /** | ||||||
|          * Minibatch size to use (number of examples to train on for each iteration) |          * Minibatch size to use (number of examples to train on for each iteration) | ||||||
|          * See also: {@link #padMinibatches} |          * See also: {@link #padMinibatches} | ||||||
|          * @param minibatchSize    Minibatch size |          * | ||||||
|  |          * @param minibatchSize Minibatch size | ||||||
|          */ |          */ | ||||||
|         public Builder minibatchSize(int minibatchSize){ |         public Builder minibatchSize(int minibatchSize) { | ||||||
|             this.minibatchSize = minibatchSize; |             this.minibatchSize = minibatchSize; | ||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
| @ -429,7 +505,7 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|          * Both options should result in exactly the same model. However, some BERT implementations may require an |          * Both options should result in exactly the same model. However, some BERT implementations may require an | ||||||
|          * exact number of examples in all minibatches to function. |          * exact number of examples in all minibatches to function. | ||||||
|          */ |          */ | ||||||
|         public Builder padMinibatches(boolean padMinibatches){ |         public Builder padMinibatches(boolean padMinibatches) { | ||||||
|             this.padMinibatches = padMinibatches; |             this.padMinibatches = padMinibatches; | ||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
| @ -437,7 +513,7 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|         /** |         /** | ||||||
|          * Set the preprocessor to be used on the MultiDataSets before returning them. Default: none (null) |          * Set the preprocessor to be used on the MultiDataSets before returning them. Default: none (null) | ||||||
|          */ |          */ | ||||||
|         public Builder preProcessor(MultiDataSetPreProcessor preProcessor){ |         public Builder preProcessor(MultiDataSetPreProcessor preProcessor) { | ||||||
|             this.preProcessor = preProcessor; |             this.preProcessor = preProcessor; | ||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
| @ -446,7 +522,7 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|          * Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised |          * Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised | ||||||
|          * use case, the labels will be ignored. |          * use case, the labels will be ignored. | ||||||
|          */ |          */ | ||||||
|         public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider){ |         public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) { | ||||||
|             this.sentenceProvider = sentenceProvider; |             this.sentenceProvider = sentenceProvider; | ||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
| @ -454,7 +530,7 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|         /** |         /** | ||||||
|          * Specify what arrays should be returned. See {@link BertIterator} for more details. |          * Specify what arrays should be returned. See {@link BertIterator} for more details. | ||||||
|          */ |          */ | ||||||
|         public Builder featureArrays(FeatureArrays featureArrays){ |         public Builder featureArrays(FeatureArrays featureArrays) { | ||||||
|             this.featureArrays = featureArrays; |             this.featureArrays = featureArrays; | ||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
| @ -465,7 +541,7 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|          * If using {@link BertWordPieceTokenizerFactory}, |          * If using {@link BertWordPieceTokenizerFactory}, | ||||||
|          * this can be obtained using {@link BertWordPieceTokenizerFactory#getVocab()} |          * this can be obtained using {@link BertWordPieceTokenizerFactory#getVocab()} | ||||||
|          */ |          */ | ||||||
|         public Builder vocabMap(Map<String,Integer> vocabMap){ |         public Builder vocabMap(Map<String, Integer> vocabMap) { | ||||||
|             this.vocabMap = vocabMap; |             this.vocabMap = vocabMap; | ||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
| @ -475,7 +551,7 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|          * masked language model. This can be used to customize how the masking is performed.<br> |          * masked language model. This can be used to customize how the masking is performed.<br> | ||||||
|          * Default: {@link BertMaskedLMMasker} |          * Default: {@link BertMaskedLMMasker} | ||||||
|          */ |          */ | ||||||
|         public Builder masker(BertSequenceMasker masker){ |         public Builder masker(BertSequenceMasker masker) { | ||||||
|             this.masker = masker; |             this.masker = masker; | ||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
| @ -485,7 +561,7 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|          * masked language model. Used to specify the format that the labels should be returned in. |          * masked language model. Used to specify the format that the labels should be returned in. | ||||||
|          * See {@link BertIterator} for more details. |          * See {@link BertIterator} for more details. | ||||||
|          */ |          */ | ||||||
|         public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat){ |         public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat) { | ||||||
|             this.unsupervisedLabelFormat = labelFormat; |             this.unsupervisedLabelFormat = labelFormat; | ||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
| @ -497,7 +573,7 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|          * the exact behaviour will depend on what masker is used.<br> |          * the exact behaviour will depend on what masker is used.<br> | ||||||
|          * Note that this must be in the vocabulary map set in {@link #vocabMap} |          * Note that this must be in the vocabulary map set in {@link #vocabMap} | ||||||
|          */ |          */ | ||||||
|         public Builder maskToken(String maskToken){ |         public Builder maskToken(String maskToken) { | ||||||
|             this.maskToken = maskToken; |             this.maskToken = maskToken; | ||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
| @ -510,12 +586,12 @@ public class BertIterator implements MultiDataSetIterator { | |||||||
|          * |          * | ||||||
|          * @param prependToken The token to start each sequence with (null: no token will be prepended) |          * @param prependToken The token to start each sequence with (null: no token will be prepended) | ||||||
|          */ |          */ | ||||||
|         public Builder prependToken(String prependToken){ |         public Builder prependToken(String prependToken) { | ||||||
|             this.prependToken = prependToken; |             this.prependToken = prependToken; | ||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         public BertIterator build(){ |         public BertIterator build() { | ||||||
|             Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed"); |             Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed"); | ||||||
|             Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required"); |             Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required"); | ||||||
|             Preconditions.checkState(vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map<String,Integer>) to set"); |             Preconditions.checkState(vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map<String,Integer>) to set"); | ||||||
|  | |||||||
| @ -26,7 +26,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; | |||||||
| import org.nd4j.linalg.dataset.api.MultiDataSet; | import org.nd4j.linalg.dataset.api.MultiDataSet; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| import org.nd4j.linalg.indexing.NDArrayIndex; | import org.nd4j.linalg.indexing.NDArrayIndex; | ||||||
| import org.nd4j.linalg.io.ClassPathResource; |  | ||||||
| import org.nd4j.linalg.primitives.Pair; | import org.nd4j.linalg.primitives.Pair; | ||||||
| import org.nd4j.resources.Resources; | import org.nd4j.resources.Resources; | ||||||
| 
 | 
 | ||||||
| @ -34,10 +33,7 @@ import java.io.File; | |||||||
| import java.io.IOException; | import java.io.IOException; | ||||||
| import java.nio.charset.Charset; | import java.nio.charset.Charset; | ||||||
| import java.nio.charset.StandardCharsets; | import java.nio.charset.StandardCharsets; | ||||||
| import java.util.Arrays; | import java.util.*; | ||||||
| import java.util.List; |  | ||||||
| import java.util.Map; |  | ||||||
| import java.util.Random; |  | ||||||
| 
 | 
 | ||||||
| import static org.junit.Assert.*; | import static org.junit.Assert.*; | ||||||
| 
 | 
 | ||||||
| @ -54,6 +50,9 @@ public class TestBertIterator extends BaseDL4JTest { | |||||||
| 
 | 
 | ||||||
|         String toTokenize1 = "I saw a girl with a telescope."; |         String toTokenize1 = "I saw a girl with a telescope."; | ||||||
|         String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; |         String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; | ||||||
|  |         List<String> forInference = new ArrayList<>(); | ||||||
|  |         forInference.add(toTokenize1); | ||||||
|  |         forInference.add(toTokenize2); | ||||||
|         BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); |         BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); | ||||||
| 
 | 
 | ||||||
|         BertIterator b = BertIterator.builder() |         BertIterator b = BertIterator.builder() | ||||||
| @ -100,12 +99,15 @@ public class TestBertIterator extends BaseDL4JTest { | |||||||
| 
 | 
 | ||||||
|         assertEquals(expF, mds.getFeatures(0)); |         assertEquals(expF, mds.getFeatures(0)); | ||||||
|         assertEquals(expM, mds.getFeaturesMaskArray(0)); |         assertEquals(expM, mds.getFeaturesMaskArray(0)); | ||||||
|  |         assertEquals(expF,b.featurizeSentences(forInference).getFirst()[0]); | ||||||
|  |         assertEquals(expM,b.featurizeSentences(forInference).getSecond()[0]); | ||||||
| 
 | 
 | ||||||
|         assertFalse(b.hasNext()); |         assertFalse(b.hasNext()); | ||||||
|         b.reset(); |         b.reset(); | ||||||
|         assertTrue(b.hasNext()); |         assertTrue(b.hasNext()); | ||||||
|         MultiDataSet mds2 = b.next(); |         MultiDataSet mds2 = b.next(); | ||||||
| 
 | 
 | ||||||
|  |         forInference.set(0,toTokenize2); | ||||||
|         //Same thing, but with segment ID also |         //Same thing, but with segment ID also | ||||||
|         b = BertIterator.builder() |         b = BertIterator.builder() | ||||||
|                 .tokenizer(t) |                 .tokenizer(t) | ||||||
| @ -118,9 +120,11 @@ public class TestBertIterator extends BaseDL4JTest { | |||||||
|                 .build(); |                 .build(); | ||||||
|         mds = b.next(); |         mds = b.next(); | ||||||
|         assertEquals(2, mds.getFeatures().length); |         assertEquals(2, mds.getFeatures().length); | ||||||
|  |         assertEquals(2,b.featurizeSentences(forInference).getFirst().length); | ||||||
|         //Segment ID should be all 0s for single segment task |         //Segment ID should be all 0s for single segment task | ||||||
|         INDArray segmentId = expM.like(); |         INDArray segmentId = expM.like(); | ||||||
|         assertEquals(segmentId, mds.getFeatures(1)); |         assertEquals(segmentId, mds.getFeatures(1)); | ||||||
|  |         assertEquals(segmentId,b.featurizeSentences(forInference).getFirst()[1]); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Test(timeout = 20000L) |     @Test(timeout = 20000L) | ||||||
| @ -157,6 +161,9 @@ public class TestBertIterator extends BaseDL4JTest { | |||||||
|     public void testLengthHandling() throws Exception { |     public void testLengthHandling() throws Exception { | ||||||
|         String toTokenize1 = "I saw a girl with a telescope."; |         String toTokenize1 = "I saw a girl with a telescope."; | ||||||
|         String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; |         String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; | ||||||
|  |         List<String> forInference = new ArrayList<>(); | ||||||
|  |         forInference.add(toTokenize1); | ||||||
|  |         forInference.add(toTokenize2); | ||||||
|         BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); |         BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); | ||||||
|         INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); |         INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); | ||||||
|         INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); |         INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); | ||||||
| @ -205,6 +212,8 @@ public class TestBertIterator extends BaseDL4JTest { | |||||||
|         assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape()); |         assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape()); | ||||||
|         assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeatures(0)); |         assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeatures(0)); | ||||||
|         assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeaturesMaskArray(0)); |         assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeaturesMaskArray(0)); | ||||||
|  |         assertEquals(mds.getFeatures(0),b.featurizeSentences(forInference).getFirst()[0]); | ||||||
|  |         assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]); | ||||||
| 
 | 
 | ||||||
|         //Clip only: clip to maximum, but don't pad if less |         //Clip only: clip to maximum, but don't pad if less | ||||||
|         b = BertIterator.builder() |         b = BertIterator.builder() | ||||||
| @ -227,6 +236,9 @@ public class TestBertIterator extends BaseDL4JTest { | |||||||
|         Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); |         Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); | ||||||
|         String toTokenize1 = "I saw a girl with a telescope."; |         String toTokenize1 = "I saw a girl with a telescope."; | ||||||
|         String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; |         String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; | ||||||
|  |         List<String> forInference = new ArrayList<>(); | ||||||
|  |         forInference.add(toTokenize1); | ||||||
|  |         forInference.add(toTokenize2); | ||||||
|         BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); |         BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); | ||||||
|         INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); |         INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); | ||||||
|         INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); |         INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); | ||||||
| @ -288,6 +300,9 @@ public class TestBertIterator extends BaseDL4JTest { | |||||||
|         assertEquals(expM, mds.getFeaturesMaskArray(0)); |         assertEquals(expM, mds.getFeaturesMaskArray(0)); | ||||||
|         assertEquals(expL, mds.getLabels(0)); |         assertEquals(expL, mds.getLabels(0)); | ||||||
|         assertEquals(expLM, mds.getLabelsMaskArray(0)); |         assertEquals(expLM, mds.getLabelsMaskArray(0)); | ||||||
|  | 
 | ||||||
|  |         assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]); | ||||||
|  |         assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     private static class TestSentenceProvider implements LabeledSentenceProvider { |     private static class TestSentenceProvider implements LabeledSentenceProvider { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user