
Merge pull request #14022 from DevinTDHa/bug/SPARKNLP-921-RobertaForQA
SPARKNLP-921: Bug Fix for BPE and RobertaForQA
maziyarpanahi authored Oct 26, 2023
2 parents d893ee1 + 31b1007 commit 9baff1d
Showing 14 changed files with 283 additions and 64 deletions.
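
For context, a minimal sketch of how the affected annotator is typically wired into a Spark NLP pipeline. This is illustrative only: the column names, the pretrained default, and the sample input are assumptions, not part of this commit.

import com.johnsnowlabs.nlp.MultiDocumentAssembler
import com.johnsnowlabs.nlp.annotators.classifier.dl.RoBertaForQuestionAnswering
import org.apache.spark.ml.Pipeline

// Question and context arrive as two separate text columns and become two document columns.
val assembler = new MultiDocumentAssembler()
  .setInputCols("question", "context")
  .setOutputCols("document_question", "document_context")

// Pretrained RoBERTa span-prediction model; its predictSpan (patched in this commit) emits the answer chunk.
val qa = RoBertaForQuestionAnswering
  .pretrained() // assumed default checkpoint
  .setInputCols("document_question", "document_context")
  .setOutputCol("answer")

val pipeline = new Pipeline().setStages(Array(assembler, qa))

// Illustrative input:
//   Seq(("Where does Clara live?", "My name is Clara and I live in Berkeley."))
//     .toDF("question", "context")
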
2 changes: 1 addition & 1 deletion src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala
@@ -46,7 +46,7 @@ private[johnsnowlabs] class Bart(
with Generate {

val bpeTokenizer: BartTokenizer = BpeTokenizer
.forModel("bart", merges = merges, vocab = vocabulary, padWithSentenceTokens = false)
.forModel("bart", merges = merges, vocab = vocabulary, padWithSequenceTokens = false)
.asInstanceOf[BartTokenizer]
private val _tfBartSignatures: Map[String, String] =
signatures.getOrElse(ModelSignatureManager.apply())
149 changes: 139 additions & 10 deletions src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala
@@ -18,10 +18,11 @@ package com.johnsnowlabs.ml.ai

import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer
import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}
import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
import com.johnsnowlabs.nlp.{ActivationFunction, Annotation, AnnotatorType}
import org.tensorflow.ndarray.buffer.IntDataBuffer

import scala.collection.JavaConverters._
@@ -63,7 +64,8 @@ private[johnsnowlabs] class RoBertaClassification(
maxSeqLength: Int,
caseSensitive: Boolean): Seq[WordpieceTokenizedSentence] = {

val bpeTokenizer = BpeTokenizer.forModel("roberta", merges, vocabulary)
val bpeTokenizer =
BpeTokenizer.forModel("roberta", merges, vocabulary, alwaysAddPrefix = false)

sentences.map { tokenIndex =>
// filter empty and only whitespace tokens
@@ -106,7 +108,8 @@
caseSensitive: Boolean): Seq[WordpieceTokenizedSentence] = {
// we need the original form of the token
// let's lowercase if needed right before the encoding
val bpeTokenizer = BpeTokenizer.forModel("roberta", merges, vocabulary)
val bpeTokenizer =
BpeTokenizer.forModel("roberta", merges, vocabulary, alwaysAddPrefix = false)
val sentences = docs.map { s => Sentence(s.result, s.begin, s.end, 0) }

sentences.map { sentence =>
@@ -115,12 +118,10 @@
val sentenceEnd = sentence.end
val sentenceIndex = sentence.index

// TODO: we should implement dedicated the tokenize and tokenizeSubText methods for full a sentence rather than token by token
val indexedTokens =
bpeTokenizer.tokenize(Sentence(content, sentenceBegin, sentenceEnd, sentenceIndex))

val wordpieceTokens =
indexedTokens.flatMap(token => bpeTokenizer.encode(token)).take(maxSeqLength)
val wordpieceTokens = bpeTokenizer.encode(indexedTokens).take(maxSeqLength)

WordpieceTokenizedSentence(wordpieceTokens)
}
@@ -372,12 +373,10 @@ private[johnsnowlabs] class RoBertaClassification(
tensors.clearTensors()

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray
val endScores: Array[Array[Float]] = endLogits.grouped(endDim).toArray

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray
val startScores: Array[Array[Float]] = startLogits.grouped(startDim).toArray

(startScores, endScores)
}
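// Note: tagSpan now returns the raw logits; masking of question/padding positions and the
// softmax are applied afterwards in predictSpan (below).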
@@ -389,4 +388,134 @@
tokenizedSentences(sentence._2).indexedTokens.find(p => p.begin == tokenPiece.begin)
}

/** Encodes two sequences to be compatible with the RoBERTa models.
*
* Unlike other models, RoBERTa requires two eos tokens to join two sequences.
*
* For example, the pair of sequences A, B should be joined to: `<s> A </s></s> B </s>`
*/
override def encodeSequence(
seq1: Seq[WordpieceTokenizedSentence],
seq2: Seq[WordpieceTokenizedSentence],
maxSequenceLength: Int): Seq[Array[Int]] = {

val question = seq1
.flatMap { wpTokSentence =>
wpTokSentence.tokens.map(t => t.pieceId)
}
.toArray
.take(maxSequenceLength - 2) ++ Array(sentenceEndTokenId, sentenceEndTokenId)

val context = seq2
.flatMap { wpTokSentence =>
wpTokSentence.tokens.map(t => t.pieceId)
}
.toArray
.take(maxSequenceLength - question.length - 2) ++ Array(sentenceEndTokenId)

Seq(Array(sentenceStartTokenId) ++ question ++ context)
}
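// Illustration of the resulting layout (hypothetical piece ids: <s> = 0, </s> = 2,
// question pieces = [100, 101], context pieces = [200, 201, 202]):
//   Array(0, 100, 101, 2, 2, 200, 201, 202, 2)  // <s> A </s></s> B </s>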

/** Calculates the normalized softmax probabilities.
*
* @param scores
* Raw logits
* @return
* Normalized softmax probabilities
*/
private def normalizedSoftmax(scores: Array[Float]): Array[Float] = {
val max = scores.max
calculateSoftmax(scores.map(_ - max))
}
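// Subtracting the maximum is the standard numerically stable softmax: for any constant c,
// exp(x_i - c) / sum_j exp(x_j - c) equals exp(x_i) / sum_j exp(x_j), so the probabilities
// are unchanged while overflow of exp on large logits is avoided.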

override def predictSpan(
documents: Seq[Annotation],
maxSentenceLength: Int,
caseSensitive: Boolean,
mergeTokenStrategy: String = MergeTokenStrategy.vocab,
engine: String = TensorFlow.name): Seq[Annotation] = {

val questionAnnot = Seq(documents.head)
val contextAnnot = documents.drop(1)

val wordPieceTokenizedQuestion =
tokenizeDocument(questionAnnot, maxSentenceLength, caseSensitive)
val wordPieceTokenizedContext =
tokenizeDocument(contextAnnot, maxSentenceLength, caseSensitive)
val questionLength = wordPieceTokenizedQuestion.head.tokens.length

val encodedInput =
encodeSequence(wordPieceTokenizedQuestion, wordPieceTokenizedContext, maxSentenceLength)
val (startLogits, endLogits) = tagSpan(encodedInput)

/** Masks the logits of question and padding tokens with a large negative value so that,
* after the softmax, they contribute (almost) zero probability to the final score.
*
* @param scores
* Logits of the combined sequences
* @return
* Scores, with the unwanted tokens masked out
*/
def maskUndesiredTokens(scores: Array[Float]): Array[Float] = {
scores.zipWithIndex.map { case (score, i) =>
// 3 added special tokens in encoded sequence (1 bos, 2 eos)
if ((i > 0 && i < questionLength + 3) || i == encodedInput.head.length - 1)
-10000.0f
else score
}
}
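// Illustration, continuing the hypothetical encoding above (questionLength = 2, length 9):
// indices 1..4 (the question pieces plus the two eos tokens) and index 8 (the final eos) are
// masked; index 0 (bos) is left untouched here and dropped later when the best indices are picked.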

val processedStartLogits = startLogits.map { scores =>
normalizedSoftmax(maskUndesiredTokens(scores))
}
val processedEndLogits = endLogits.map { scores =>
normalizedSoftmax(maskUndesiredTokens(scores))
}

val startScores = processedStartLogits.transpose.map(_.sum / startLogits.length)
val endScores = processedEndLogits.transpose.map(_.sum / endLogits.length)

// Drop BOS token from valid results
val startIndex = startScores.zipWithIndex.drop(1).maxBy(_._1)
val endIndex = endScores.zipWithIndex.drop(1).maxBy(_._1)

val offsetStartIndex = 3 // 3 added special tokens
val offsetEndIndex = offsetStartIndex - 1

val allTokenPieces =
wordPieceTokenizedQuestion.head.tokens ++ wordPieceTokenizedContext.flatMap(x => x.tokens)
val decodedAnswer =
allTokenPieces.slice(startIndex._2 - offsetStartIndex, endIndex._2 - offsetEndIndex)
val content =
mergeTokenStrategy match {
case MergeTokenStrategy.vocab =>
decodedAnswer.filter(_.isWordStart).map(x => x.token).mkString(" ")
case MergeTokenStrategy.sentencePiece =>
val token = ""
decodedAnswer
.map(x =>
if (x.isWordStart) " " + token + x.token
else token + x.token)
.mkString("")
.trim
}

val totalScore = startIndex._1 * endIndex._1
Seq(
Annotation(
annotatorType = AnnotatorType.CHUNK,
begin = 0,
end = if (content.isEmpty) 0 else content.length - 1,
result = content,
metadata = Map(
"sentence" -> "0",
"chunk" -> "0",
"start" -> decodedAnswer.head.begin.toString,
"start_score" -> startIndex._1.toString,
"end" -> decodedAnswer.last.end.toString,
"end_score" -> endIndex._1.toString,
"score" -> totalScore.toString)))

}

}
@@ -402,8 +402,7 @@ class GPT2Transformer(override val uid: String)
.forModel(
"gpt2",
merges = $$(merges),
vocab = $$(vocabulary),
padWithSentenceTokens = false)
vocab = $$(vocabulary))
.asInstanceOf[Gpt2Tokenizer]

_tfModel = Some(
@@ -20,12 +20,12 @@ class BartTokenizer(
merges: Map[(String, String), Int],
vocab: Map[String, Int],
specialTokens: SpecialTokens,
padWithSentenceTokens: Boolean = false,
addPrefixSpace: Boolean = false)
padWithSequenceTokens: Boolean = false,
addPrefixSpaceToSentence: Boolean = false)
extends Gpt2Tokenizer(
merges,
vocab,
specialTokens,
padWithSentenceTokens,
padWithSequenceTokens,
prependString = "Ġ",
addPrefixSpace)
addPrefixSpaceToSentence)
@@ -23,24 +23,30 @@ import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/** A BPE Tokenizer based on GPT2's tokenization scheme. The tokenization can then be used for
* models based on this scheme (e.g. GPT2, roBERTa, DeBERTa) TODO: truncation assumed?
* models based on this scheme (e.g. GPT2, roBERTa, DeBERTa)
*
* TODO: truncation assumed?
*
* @param merges
* Map of tokens that are mergeable
* @param vocab
* Map of tokens to encoded representation
* @param specialTokens
* Collection of special tokens
* @param padWithSentenceTokens
* @param padWithSequenceTokens
* Whether to pad the sentence with sentence tokens at the start and end
* @param addPrefixSpace
* @param addPrefixSpaceToSentence
* Whether to add a space to the first word of a sentence
* @param alwaysAddPrefix
* Whether to always prefix token ids with `prefixForPieceId`
*/
private[nlp] abstract class BpeTokenizer(
val merges: Map[(String, String), Int],
val vocab: Map[String, Int],
val specialTokens: SpecialTokens,
val padWithSentenceTokens: Boolean,
val addPrefixSpace: Boolean) {
val padWithSequenceTokens: Boolean,
val addPrefixSpaceToSentence: Boolean,
val alwaysAddPrefix: Boolean) {

protected val bpeRanks: Map[(String, String), Int] = {
merges
Expand All @@ -60,8 +66,8 @@ private[nlp] abstract class BpeTokenizer(
}

// Can be overridden in inherited class
protected val prependForPieceId: Option[String] = None
protected val appendForPieceId: Option[String] = None
protected val prefixForPieceId: Option[String] = None
protected val suffixForPieceId: Option[String] = None

protected def performMerges(
wordChars: Array[String],
@@ -122,16 +128,16 @@
val isWordStart = indToken.begin == indexes._1
val isDocumentStart = indToken.begin == 0
var processedSubWord = subWord
processedSubWord = if (isDocumentStart && !addPrefixSpace) {
processedSubWord = if (isDocumentStart && !addPrefixSpaceToSentence) {
processedSubWord
} else
prependForPieceId match {
case None => processedSubWord
case Some(prepend) =>
prefixForPieceId match {
case Some(prepend) if alwaysAddPrefix =>
if (isWordStart && subWord.indexOf(prepend) < 0) prepend + processedSubWord
else processedSubWord
case _ => processedSubWord
}
processedSubWord = appendForPieceId match {
processedSubWord = suffixForPieceId match {
case None => processedSubWord
case Some(append) =>
val isWordEnd = indToken.end == indexes._2
@@ -239,7 +245,7 @@ private[nlp] abstract class BpeTokenizer(
}

/** Needs to be implemented */
def tokenizeSubText(text: String, indexOffset: Int): Array[IndexedToken]
protected def tokenizeSubText(text: String, indexOffset: Int): Array[IndexedToken]

/** Special tokens of the model for processing */
val sentencePadding: (String, String) =
@@ -264,7 +270,7 @@
textList = splitTexts.clone()
}

if (padWithSentenceTokens) {
if (padWithSequenceTokens) {
text = sentencePadding._1 + text + sentencePadding._2
splitTexts.prepend(sentencePadding._1)
splitTexts.append(sentencePadding._2)
@@ -310,9 +316,10 @@ object BpeTokenizer {
modelType: String,
merges: Map[(String, String), Int],
vocab: Map[String, Int],
padWithSentenceTokens: Boolean = false,
addPrefixSpace: Boolean = false,
specialTokens: Option[SpecialTokens] = None): BpeTokenizer = {
padWithSequenceTokens: Boolean = false,
addPrefixSpaceToSentence: Boolean = false,
specialTokens: Option[SpecialTokens] = None,
alwaysAddPrefix: Boolean = true): BpeTokenizer = {

def modelSpecialTokens() = specialTokens match {
case Some(specialTok) => specialTok
@@ -325,24 +332,26 @@
merges,
vocab,
modelSpecialTokens(),
padWithSentenceTokens,
addPrefixSpace = addPrefixSpace)
padWithSequenceTokens,
addPrefixSpaceToSentence = addPrefixSpaceToSentence,
alwaysAddPrefix = alwaysAddPrefix)
case "xlm" =>
new XlmTokenizer(merges, vocab, modelSpecialTokens(), padWithSentenceTokens)
new XlmTokenizer(merges, vocab, modelSpecialTokens(), padWithSequenceTokens)
case "gpt2" =>
new Gpt2Tokenizer(
merges,
vocab,
modelSpecialTokens(),
padWithSentenceTokens,
addPrefixSpace = addPrefixSpace)
padWithSequenceTokens,
addPrefixSpaceToSentence = addPrefixSpaceToSentence,
alwaysAddPrefix = alwaysAddPrefix)
case "bart" =>
new BartTokenizer(
merges,
vocab,
modelSpecialTokens(),
padWithSentenceTokens,
addPrefixSpace = addPrefixSpace)
padWithSequenceTokens,
addPrefixSpaceToSentence = addPrefixSpaceToSentence)
case _ =>
throw new IllegalArgumentException("Model type \"" + modelType + "\" not supported yet.")
}
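
A short sketch of the updated factory call, mirroring the flags the RoBERTa QA classification code above now passes (illustrative; assumes `merges` and `vocab` are already loaded):

val robertaTokenizer = BpeTokenizer
  .forModel(
    "roberta",
    merges,
    vocab,
    padWithSequenceTokens = false,
    addPrefixSpaceToSentence = false,
    alwaysAddPrefix = false)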