Skip to content

Commit

Permalink
Introducing BertForMultipleChoice transformer (#14435)
Browse files Browse the repository at this point in the history
* [SPARKNLP-1084] Introducing BertForMultipleChoice

* [SPARKNLP-1084] Introducing BertForMultipleChoice transformer
  • Loading branch information
danilojsl authored Oct 18, 2024
1 parent 2e3d1c2 commit a7e9c0e
Show file tree
Hide file tree
Showing 10 changed files with 860 additions and 13 deletions.
2 changes: 1 addition & 1 deletion python/sparknlp/annotator/classifier_dl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@
from sparknlp.annotator.classifier_dl.mpnet_for_token_classification import *
from sparknlp.annotator.classifier_dl.albert_for_zero_shot_classification import *
from sparknlp.annotator.classifier_dl.camembert_for_zero_shot_classification import *

from sparknlp.annotator.classifier_dl.bert_for_multiple_choice import *
161 changes: 161 additions & 0 deletions python/sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright 2017-2024 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sparknlp.common import *

class BertForMultipleChoice(AnnotatorModel,
HasCaseSensitiveProperties,
HasBatchedAnnotate,
HasEngine,
HasMaxSentenceLengthLimit):
"""BertForMultipleChoice can load BERT Models with a multiple choice classification head on top
(a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
Pretrained models can be loaded with :meth:`.pretrained` of the companion
object:
>>> spanClassifier = BertForMultipleChoice.pretrained() \\
... .setInputCols(["document_question", "document_context"]) \\
... .setOutputCol("answer")
The default model is ``"bert_base_uncased_multiple_choice"``, if no name is
provided.
For available pretrained models please see the `Models Hub
<https://sparknlp.org/models?task=Multiple+Choice>`__.
To see which models are compatible and how to import them see
`Import Transformers into Spark NLP 🚀
<https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.
====================== ======================
Input Annotation types Output Annotation type
====================== ======================
``DOCUMENT, DOCUMENT`` ``CHUNK``
====================== ======================
Parameters
----------
batchSize
Batch size. Large values allows faster processing but requires more
memory, by default 8
caseSensitive
Whether to ignore case in tokens for embeddings matching, by default
False
maxSentenceLength
Max sentence length to process, by default 512
Examples
--------
>>> import sparknlp
>>> from sparknlp.base import *
>>> from sparknlp.annotator import *
>>> from pyspark.ml import Pipeline
>>> documentAssembler = MultiDocumentAssembler() \\
... .setInputCols(["question", "context"]) \\
... .setOutputCols(["document_question", "document_context"])
>>> questionAnswering = BertForMultipleChoice.pretrained() \\
... .setInputCols(["document_question", "document_context"]) \\
... .setOutputCol("answer") \\
... .setCaseSensitive(False)
>>> pipeline = Pipeline().setStages([
... documentAssembler,
... questionAnswering
... ])
>>> data = spark.createDataFrame([["The Eiffel Tower is located in which country??", "Germany, France, Italy"]]).toDF("question", "context")
>>> result = pipeline.fit(data).transform(data)
>>> result.select("answer.result").show(truncate=False)
+--------------------+
|result |
+--------------------+
|[France] |
+--------------------+
"""
name = "BertForMultipleChoice"

inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.DOCUMENT]

outputAnnotatorType = AnnotatorType.CHUNK

choicesDelimiter = Param(Params._dummy(),
"choicesDelimiter",
"Delimiter character use to split the choices",
TypeConverters.toString)

def setChoicesDelimiter(self, value):
"""Sets delimiter character use to split the choices
Parameters
----------
value : string
Delimiter character use to split the choices
"""
return self._set(caseSensitive=value)

@keyword_only
def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.BertForMultipleChoice",
java_model=None):
super(BertForMultipleChoice, self).__init__(
classname=classname,
java_model=java_model
)
self._setDefault(
batchSize=4,
maxSentenceLength=512,
caseSensitive=False,
choicesDelimiter = ","
)

@staticmethod
def loadSavedModel(folder, spark_session):
"""Loads a locally saved model.
Parameters
----------
folder : str
Folder of the saved model
spark_session : pyspark.sql.SparkSession
The current SparkSession
Returns
-------
BertForQuestionAnswering
The restored model
"""
from sparknlp.internal import _BertMultipleChoiceLoader
jModel = _BertMultipleChoiceLoader(folder, spark_session._jsparkSession)._java_obj
return BertForMultipleChoice(java_model=jModel)

@staticmethod
def pretrained(name="bert_base_uncased_multiple_choice", lang="en", remote_loc=None):
"""Downloads and loads a pretrained model.
Parameters
----------
name : str, optional
Name of the pretrained model, by default
"bert_base_uncased_multiple_choice"
lang : str, optional
Language of the pretrained model, by default "en"
remote_loc : str, optional
Optional remote address of the resource, by default None. Will use
Spark NLPs repositories otherwise.
Returns
-------
BertForQuestionAnswering
The restored model
"""
from sparknlp.pretrained import ResourceDownloader
return ResourceDownloader.downloadModel(BertForMultipleChoice, name, lang, remote_loc)
7 changes: 7 additions & 0 deletions python/sparknlp/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ def __init__(self, path, jspark):
jspark,
)

class _BertMultipleChoiceLoader(ExtendedJavaWrapper):
def __init__(self, path, jspark):
super(_BertMultipleChoiceLoader, self).__init__(
"com.johnsnowlabs.nlp.annotators.classifier.dl.BertForMultipleChoice.loadSavedModel",
path,
jspark,
)

class _DeBERTaLoader(ExtendedJavaWrapper):
def __init__(self, path, jspark):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2017-2024 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import pytest

from sparknlp.annotator import *
from sparknlp.base import *
from test.util import SparkContextForTest


class BertForMultipleChoiceTestSetup(unittest.TestCase):
def setUp(self):
self.spark = SparkContextForTest.spark
self.question = "The Eiffel Tower is located in which country?"
self.choices = "Germany, France, Italy"

self.spark = SparkContextForTest.spark
empty_df = self.spark.createDataFrame([[""]]).toDF("text")

document_assembler = MultiDocumentAssembler() \
.setInputCols(["question", "context"]) \
.setOutputCols(["document_question", "document_context"])

bert_for_multiple_choice = BertForMultipleChoice.pretrained() \
.setInputCols(["document_question", "document_context"]) \
.setOutputCol("answer") \

pipeline = Pipeline(stages=[document_assembler, bert_for_multiple_choice])

self.pipeline_model = pipeline.fit(empty_df)


@pytest.mark.slow
class BertForMultipleChoiceTest(BertForMultipleChoiceTestSetup, unittest.TestCase):

def setUp(self):
super().setUp()
self.data = self.spark.createDataFrame([[self.question, self.choices]]).toDF("question","context")
self.data.show(truncate=False)

def test_run(self):
result_df = self.pipeline_model.transform(self.data)
result_df.show(truncate=False)
for row in result_df.collect():
self.assertTrue(row["answer"][0].result != "")


@pytest.mark.slow
class LightBertForMultipleChoiceTest(BertForMultipleChoiceTestSetup, unittest.TestCase):

def setUp(self):
super().setUp()

def runTest(self):
light_pipeline = LightPipeline(self.pipeline_model)
annotations_result = light_pipeline.fullAnnotate(self.question,self.choices)
print(annotations_result)
for result in annotations_result:
self.assertTrue(result["answer"][0].result != "")

result = light_pipeline.annotate(self.question,self.choices)
print(result)
self.assertTrue(result["answer"] != "")
92 changes: 91 additions & 1 deletion src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ private[johnsnowlabs] class BertClassification(

// we need the original form of the token
// let's lowercase if needed right before the encoding
val basicTokenizer = new BasicTokenizer(caseSensitive = true, hasBeginEnd = false)
val basicTokenizer = new BasicTokenizer(caseSensitive = caseSensitive, hasBeginEnd = false)
val encoder = new WordpieceEncoder(vocabulary)
val sentences = docs.map { s => Sentence(s.result, s.begin, s.end, 0) }

Expand Down Expand Up @@ -546,6 +546,15 @@ private[johnsnowlabs] class BertClassification(
(startScores, endScores)
}

override def tagSpanMultipleChoice(batch: Seq[Array[Int]]): Array[Float] = {
val logits = detectedEngine match {
case ONNX.name => computeLogitsMultipleChoiceWithOnnx(batch)
case Openvino.name => computeLogitsMultipleChoiceWithOv(batch)
}

calculateSoftmax(logits)
}

private def computeLogitsWithTF(
batch: Seq[Array[Int]],
maxSentenceLength: Int): (Array[Float], Array[Float]) = {
Expand Down Expand Up @@ -732,6 +741,87 @@ private[johnsnowlabs] class BertClassification(
}
}

private def computeLogitsMultipleChoiceWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {
val sequenceLength = batch.head.length
val inputIds = Array(batch.map(x => x.map(_.toLong)).toArray)
val attentionMask = Array(
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)
val tokenTypeIds = Array(batch.map(_ => Array.fill(sequenceLength)(0L)).toArray)

val (ortSession, ortEnv) = onnxWrapper.get.getSession(onnxSessionOptions)
val tokenTensors = OnnxTensor.createTensor(ortEnv, inputIds)
val maskTensors = OnnxTensor.createTensor(ortEnv, attentionMask)
val segmentTensors = OnnxTensor.createTensor(ortEnv, tokenTypeIds)

val inputs =
Map(
"input_ids" -> tokenTensors,
"attention_mask" -> maskTensors,
"token_type_ids" -> segmentTensors).asJava

try {
val output = ortSession.run(inputs)
try {

val logits = output
.get("logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()

tokenTensors.close()
maskTensors.close()
segmentTensors.close()

logits
} finally if (output != null) output.close()
} catch {
case e: Exception =>
// Log the exception as a warning
println("Exception in computeLogitsMultipleChoiceWithOnnx: ", e)
// Rethrow the exception to propagate it further
throw e
}
}

private def computeLogitsMultipleChoiceWithOv(batch: Seq[Array[Int]]): Array[Float] = {
val (numChoices, sequenceLength) = (batch.length, batch.head.length)
// batch_size, num_choices, sequence_length
val shape = Some(Array(1, numChoices, sequenceLength))
val (tokenTensors, maskTensors, segmentTensors) =
PrepareEmbeddings.prepareOvLongBatchTensorsWithSegment(
batch,
sequenceLength,
numChoices,
sentencePadTokenId,
shape)

val compiledModel = openvinoWrapper.get.getCompiledModel()
val inferRequest = compiledModel.create_infer_request()
inferRequest.set_tensor("input_ids", tokenTensors)
inferRequest.set_tensor("attention_mask", maskTensors)
inferRequest.set_tensor("token_type_ids", segmentTensors)

inferRequest.infer()

try {
try {
val logits = inferRequest
.get_output_tensor()
.data()

logits
}
} catch {
case e: Exception =>
// Log the exception as a warning
logger.warn("Exception in computeLogitsMultipleChoiceWithOv", e)
// Rethrow the exception to propagate it further
throw e
}
}

def findIndexedToken(
tokenizedSentences: Seq[TokenizedSentence],
sentence: (WordpieceTokenizedSentence, Int),
Expand Down
Loading

0 comments on commit a7e9c0e

Please sign in to comment.