SPARKNLP-921: Bug Fix for BPE and RobertaForQA #14022

Merged
2 changes: 1 addition & 1 deletion src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala
@@ -46,7 +46,7 @@ private[johnsnowlabs] class Bart(
with Generate {

val bpeTokenizer: BartTokenizer = BpeTokenizer
.forModel("bart", merges = merges, vocab = vocabulary, padWithSentenceTokens = false)
.forModel("bart", merges = merges, vocab = vocabulary, padWithSequenceTokens = false)
.asInstanceOf[BartTokenizer]
private val _tfBartSignatures: Map[String, String] =
signatures.getOrElse(ModelSignatureManager.apply())
149 changes: 139 additions & 10 deletions src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala
@@ -18,10 +18,11 @@ package com.johnsnowlabs.ml.ai

import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
+ import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer
import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}
- import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
+ import com.johnsnowlabs.nlp.{ActivationFunction, Annotation, AnnotatorType}
import org.tensorflow.ndarray.buffer.IntDataBuffer

import scala.collection.JavaConverters._
@@ -63,7 +64,8 @@ private[johnsnowlabs] class RoBertaClassification(
maxSeqLength: Int,
caseSensitive: Boolean): Seq[WordpieceTokenizedSentence] = {

- val bpeTokenizer = BpeTokenizer.forModel("roberta", merges, vocabulary)
+ val bpeTokenizer =
+ BpeTokenizer.forModel("roberta", merges, vocabulary, alwaysAddPrefix = false)

sentences.map { tokenIndex =>
// filter empty and only whitespace tokens
@@ -106,7 +108,8 @@
caseSensitive: Boolean): Seq[WordpieceTokenizedSentence] = {
// we need the original form of the token
// let's lowercase if needed right before the encoding
- val bpeTokenizer = BpeTokenizer.forModel("roberta", merges, vocabulary)
+ val bpeTokenizer =
+ BpeTokenizer.forModel("roberta", merges, vocabulary, alwaysAddPrefix = false)
val sentences = docs.map { s => Sentence(s.result, s.begin, s.end, 0) }

sentences.map { sentence =>
@@ -115,12 +118,10 @@
val sentenceEnd = sentence.end
val sentenceIndex = sentence.index

- // TODO: we should implement dedicated the tokenize and tokenizeSubText methods for full a sentence rather than token by token
val indexedTokens =
bpeTokenizer.tokenize(Sentence(content, sentenceBegin, sentenceEnd, sentenceIndex))

- val wordpieceTokens =
- indexedTokens.flatMap(token => bpeTokenizer.encode(token)).take(maxSeqLength)
+ val wordpieceTokens = bpeTokenizer.encode(indexedTokens).take(maxSeqLength)

WordpieceTokenizedSentence(wordpieceTokens)
}
@@ -372,12 +373,10 @@ private[johnsnowlabs] class RoBertaClassification(
tensors.clearTensors()

val endDim = endLogits.length / batchLength
- val endScores: Array[Array[Float]] =
- endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray
+ val endScores: Array[Array[Float]] = endLogits.grouped(endDim).toArray

val startDim = startLogits.length / batchLength
- val startScores: Array[Array[Float]] =
- startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray
+ val startScores: Array[Array[Float]] = startLogits.grouped(startDim).toArray

(startScores, endScores)
}
@@ -389,4 +388,134 @@
tokenizedSentences(sentence._2).indexedTokens.find(p => p.begin == tokenPiece.begin)
}

/** Encodes two sequences to be compatible with the RoBerta models.
*
* Unlike other models, RoBERTa requires two eos tokens to join two sequences.
*
* For example, the pair of sequences A, B should be joined to: `<s> A </s></s> B </s>`
*/
override def encodeSequence(
seq1: Seq[WordpieceTokenizedSentence],
seq2: Seq[WordpieceTokenizedSentence],
maxSequenceLength: Int): Seq[Array[Int]] = {

val question = seq1
.flatMap { wpTokSentence =>
wpTokSentence.tokens.map(t => t.pieceId)
}
.toArray
.take(maxSequenceLength - 2) ++ Array(sentenceEndTokenId, sentenceEndTokenId)

val context = seq2
.flatMap { wpTokSentence =>
wpTokSentence.tokens.map(t => t.pieceId)
}
.toArray
.take(maxSequenceLength - question.length - 2) ++ Array(sentenceEndTokenId)

Seq(Array(sentenceStartTokenId) ++ question ++ context)
}
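// Worked example (editorial illustration; the token ids are hypothetical): with
// sentenceStartTokenId = 0 and sentenceEndTokenId = 2, a question [713, 16] and a
// context [100, 52, 84] are joined by encodeSequence above into
//   Array(0, 713, 16, 2, 2, 100, 52, 84, 2)
// i.e. `<s> question </s></s> context </s>`, as described in the Scaladoc.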

/** Calculates the normalized softmax probabilities.
*
* @param scores
* Raw logits
* @return
* Normalized softmax probabilities
*/
private def normalizedSoftmax(scores: Array[Float]): Array[Float] = {
val max = scores.max
calculateSoftmax(scores.map(_ - max))
}

override def predictSpan(
documents: Seq[Annotation],
maxSentenceLength: Int,
caseSensitive: Boolean,
mergeTokenStrategy: String = MergeTokenStrategy.vocab,
engine: String = TensorFlow.name): Seq[Annotation] = {

val questionAnnot = Seq(documents.head)
val contextAnnot = documents.drop(1)

val wordPieceTokenizedQuestion =
tokenizeDocument(questionAnnot, maxSentenceLength, caseSensitive)
val wordPieceTokenizedContext =
tokenizeDocument(contextAnnot, maxSentenceLength, caseSensitive)
val questionLength = wordPieceTokenizedQuestion.head.tokens.length

val encodedInput =
encodeSequence(wordPieceTokenizedQuestion, wordPieceTokenizedContext, maxSentenceLength)
val (startLogits, endLogits) = tagSpan(encodedInput)

/** Masks the logits of question and padding tokens with a large negative value so that they
* receive (almost) zero probability in the final softmax.
*
* @param scores
* Logits of the combined sequences
* @return
* Logits, with the unwanted positions masked out
*/
def maskUndesiredTokens(scores: Array[Float]): Array[Float] = {
scores.zipWithIndex.map { case (score, i) =>
// 3 added special tokens in encoded sequence (1 bos, 2 eos)
if ((i > 0 && i < questionLength + 3) || i == encodedInput.head.length - 1)
-10000.0f
else score
}
}
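// Illustration (editorial, hypothetical sizes): for questionLength = 2 and a four-token
// context the encoded input has 10 positions, <s> q1 q2 </s> </s> c1 c2 c3 c4 </s>.
// The condition above masks indices 1..4 (the question plus its two </s>) and index 9
// (the final </s>), leaving index 0 and the context positions 5..8; index 0 is skipped
// later by the drop(1) below.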

val processedStartLogits = startLogits.map { scores =>
normalizedSoftmax(maskUndesiredTokens(scores))
}
val processedEndLogits = endLogits.map { scores =>
normalizedSoftmax(maskUndesiredTokens(scores))
}

val startScores = processedStartLogits.transpose.map(_.sum / startLogits.length)
val endScores = processedEndLogits.transpose.map(_.sum / endLogits.length)

// Drop BOS token from valid results
val startIndex = startScores.zipWithIndex.drop(1).maxBy(_._1)
val endIndex = endScores.zipWithIndex.drop(1).maxBy(_._1)

val offsetStartIndex = 3 // 3 added special tokens
val offsetEndIndex = offsetStartIndex - 1

val allTokenPieces =
wordPieceTokenizedQuestion.head.tokens ++ wordPieceTokenizedContext.flatMap(x => x.tokens)
val decodedAnswer =
allTokenPieces.slice(startIndex._2 - offsetStartIndex, endIndex._2 - offsetEndIndex)
val content =
mergeTokenStrategy match {
case MergeTokenStrategy.vocab =>
decodedAnswer.filter(_.isWordStart).map(x => x.token).mkString(" ")
case MergeTokenStrategy.sentencePiece =>
val token = ""
decodedAnswer
.map(x =>
if (x.isWordStart) " " + token + x.token
else token + x.token)
.mkString("")
.trim
}

val totalScore = startIndex._1 * endIndex._1
Seq(
Annotation(
annotatorType = AnnotatorType.CHUNK,
begin = 0,
end = if (content.isEmpty) 0 else content.length - 1,
result = content,
metadata = Map(
"sentence" -> "0",
"chunk" -> "0",
"start" -> decodedAnswer.head.begin.toString,
"start_score" -> startIndex._1.toString,
"end" -> decodedAnswer.last.end.toString,
"end_score" -> endIndex._1.toString,
"score" -> totalScore.toString)))

}

}
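The question-answering path above now returns raw logits from tagSpan and applies the softmax in predictSpan, after masking the question and padding positions with a large negative logit. Subtracting the row maximum in normalizedSoftmax leaves the softmax result unchanged, since the shift cancels in the normalization, while avoiding overflow in the exponential. The standalone Scala sketch below mirrors that scoring scheme outside of Spark NLP; the object name, logits, and question length are hypothetical.

object SpanScoringSketch extends App {

  // Plain softmax over one row of logits.
  def softmax(scores: Array[Float]): Array[Float] = {
    val expScores = scores.map(s => math.exp(s).toFloat)
    val sum = expScores.sum
    expScores.map(_ / sum)
  }

  // Numerically stable variant, as in normalizedSoftmax above: shifting by the maximum
  // cancels in the normalization, so the probabilities are identical.
  def normalizedSoftmax(scores: Array[Float]): Array[Float] =
    softmax(scores.map(_ - scores.max))

  // Mirror of maskUndesiredTokens: suppress the question tokens, their two </s>
  // separators, and the trailing </s>, so only context positions can win.
  def maskUndesiredTokens(scores: Array[Float], questionLength: Int): Array[Float] =
    scores.zipWithIndex.map { case (score, i) =>
      if ((i > 0 && i < questionLength + 3) || i == scores.length - 1) -10000.0f else score
    }

  // Hypothetical start logits for <s> q1 q2 </s> </s> c1 c2 c3 </s>
  val startLogits = Array(0.1f, 4.0f, 3.0f, 0.0f, 0.0f, 2.5f, 1.0f, 0.5f, 0.0f)
  val startScores = normalizedSoftmax(maskUndesiredTokens(startLogits, questionLength = 2))

  // As in predictSpan, drop the <s> position and pick the best remaining index.
  val (bestScore, bestIndex) = startScores.zipWithIndex.drop(1).maxBy(_._1)
  println(s"best start index: $bestIndex (score $bestScore)")
}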
@@ -402,8 +402,7 @@ class GPT2Transformer(override val uid: String)
.forModel(
"gpt2",
merges = $$(merges),
- vocab = $$(vocabulary),
- padWithSentenceTokens = false)
+ vocab = $$(vocabulary))
.asInstanceOf[Gpt2Tokenizer]

_tfModel = Some(
@@ -20,12 +20,12 @@ class BartTokenizer(
merges: Map[(String, String), Int],
vocab: Map[String, Int],
specialTokens: SpecialTokens,
- padWithSentenceTokens: Boolean = false,
- addPrefixSpace: Boolean = false)
+ padWithSequenceTokens: Boolean = false,
+ addPrefixSpaceToSentence: Boolean = false)
extends Gpt2Tokenizer(
merges,
vocab,
specialTokens,
- padWithSentenceTokens,
+ padWithSequenceTokens,
prependString = "Ġ",
- addPrefixSpace)
+ addPrefixSpaceToSentence)
@@ -23,24 +23,30 @@ import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/** A BPE Tokenizer based on GPT2's tokenization scheme. The tokenization can then be used for
- * models based on this scheme (e.g. GPT2, roBERTa, DeBERTa) TODO: truncation assumed?
+ * models based on this scheme (e.g. GPT2, roBERTa, DeBERTa)
+ *
+ * TODO: truncation assumed?
+ *
* @param merges
* Map of tokens that are mergeable
* @param vocab
* Map of tokens to encoded representation
* @param specialTokens
* Collection of special tokens
- * @param padWithSentenceTokens
+ * @param padWithSequenceTokens
* Whether to pad the sentence with sentence tokens at the start and end
- * @param addPrefixSpace
+ * @param addPrefixSpaceToSentence
* Whether to add a space to the first word of a sentence
+ * @param alwaysAddPrefix
+ * Whether to always prefix token ids with `prefixForPieceId`
*/
private[nlp] abstract class BpeTokenizer(
val merges: Map[(String, String), Int],
val vocab: Map[String, Int],
val specialTokens: SpecialTokens,
- val padWithSentenceTokens: Boolean,
- val addPrefixSpace: Boolean) {
+ val padWithSequenceTokens: Boolean,
+ val addPrefixSpaceToSentence: Boolean,
+ val alwaysAddPrefix: Boolean) {

protected val bpeRanks: Map[(String, String), Int] = {
merges
@@ -60,8 +66,8 @@ private[nlp] abstract class BpeTokenizer(
}

// Can be overridden in inherited class
- protected val prependForPieceId: Option[String] = None
- protected val appendForPieceId: Option[String] = None
+ protected val prefixForPieceId: Option[String] = None
+ protected val suffixForPieceId: Option[String] = None

protected def performMerges(
wordChars: Array[String],
@@ -122,16 +128,16 @@
val isWordStart = indToken.begin == indexes._1
val isDocumentStart = indToken.begin == 0
var processedSubWord = subWord
- processedSubWord = if (isDocumentStart && !addPrefixSpace) {
+ processedSubWord = if (isDocumentStart && !addPrefixSpaceToSentence) {
processedSubWord
} else
- prependForPieceId match {
- case None => processedSubWord
- case Some(prepend) =>
+ prefixForPieceId match {
+ case Some(prepend) if alwaysAddPrefix =>
if (isWordStart && subWord.indexOf(prepend) < 0) prepend + processedSubWord
else processedSubWord
+ case _ => processedSubWord
}
- processedSubWord = appendForPieceId match {
+ processedSubWord = suffixForPieceId match {
case None => processedSubWord
case Some(append) =>
val isWordEnd = indToken.end == indexes._2
@@ -239,7 +245,7 @@
}

/** Needs to be implemented */
- def tokenizeSubText(text: String, indexOffset: Int): Array[IndexedToken]
+ protected def tokenizeSubText(text: String, indexOffset: Int): Array[IndexedToken]

/** Special tokens of the model for processing */
val sentencePadding: (String, String) =
@@ -264,7 +270,7 @@
textList = splitTexts.clone()
}

- if (padWithSentenceTokens) {
+ if (padWithSequenceTokens) {
text = sentencePadding._1 + text + sentencePadding._2
splitTexts.prepend(sentencePadding._1)
splitTexts.append(sentencePadding._2)
@@ -310,9 +316,10 @@ object BpeTokenizer {
modelType: String,
merges: Map[(String, String), Int],
vocab: Map[String, Int],
- padWithSentenceTokens: Boolean = false,
- addPrefixSpace: Boolean = false,
- specialTokens: Option[SpecialTokens] = None): BpeTokenizer = {
+ padWithSequenceTokens: Boolean = false,
+ addPrefixSpaceToSentence: Boolean = false,
+ specialTokens: Option[SpecialTokens] = None,
+ alwaysAddPrefix: Boolean = true): BpeTokenizer = {

def modelSpecialTokens() = specialTokens match {
case Some(specialTok) => specialTok
@@ -325,24 +332,26 @@
merges,
vocab,
modelSpecialTokens(),
- padWithSentenceTokens,
- addPrefixSpace = addPrefixSpace)
+ padWithSequenceTokens,
+ addPrefixSpaceToSentence = addPrefixSpaceToSentence,
+ alwaysAddPrefix = alwaysAddPrefix)
case "xlm" =>
- new XlmTokenizer(merges, vocab, modelSpecialTokens(), padWithSentenceTokens)
+ new XlmTokenizer(merges, vocab, modelSpecialTokens(), padWithSequenceTokens)
case "gpt2" =>
new Gpt2Tokenizer(
merges,
vocab,
modelSpecialTokens(),
- padWithSentenceTokens,
- addPrefixSpace = addPrefixSpace)
+ padWithSequenceTokens,
+ addPrefixSpaceToSentence = addPrefixSpaceToSentence,
+ alwaysAddPrefix = alwaysAddPrefix)
case "bart" =>
new BartTokenizer(
merges,
vocab,
modelSpecialTokens(),
- padWithSentenceTokens,
- addPrefixSpace = addPrefixSpace)
+ padWithSequenceTokens,
+ addPrefixSpaceToSentence = addPrefixSpaceToSentence)
case _ =>
throw new IllegalArgumentException("Model type \"" + modelType + "\" not supported yet.")
}
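For reference, a minimal sketch of how the renamed forModel parameters line up after this change, assuming it is compiled inside the Spark NLP nlp packages (the tokenizer classes above are package-private). The merges and vocab maps are toy placeholders; real pipelines load them from the model's merges and vocabulary files.

import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer

// Toy placeholder resources; a real RoBERTa model ships complete merges/vocab maps.
val merges: Map[(String, String), Int] = Map(("Ġ", "a") -> 0)
val vocab: Map[String, Int] =
  Map("<s>" -> 0, "<pad>" -> 1, "</s>" -> 2, "<unk>" -> 3, "<mask>" -> 4, "Ġa" -> 5)

// RoBERTa classification/QA now builds its tokenizer without the automatic "Ġ" prefix:
val robertaTokenizer = BpeTokenizer.forModel(
  "roberta",
  merges,
  vocab,
  padWithSequenceTokens = false,    // renamed from padWithSentenceTokens
  addPrefixSpaceToSentence = false, // renamed from addPrefixSpace
  alwaysAddPrefix = false)          // new flag; defaults to true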