diff --git a/python/sparknlp/annotator/classifier_dl/__init__.py b/python/sparknlp/annotator/classifier_dl/__init__.py index bbd9f60a8dfbba..2b5e30fc3ff359 100644 --- a/python/sparknlp/annotator/classifier_dl/__init__.py +++ b/python/sparknlp/annotator/classifier_dl/__init__.py @@ -54,4 +54,4 @@ from sparknlp.annotator.classifier_dl.mpnet_for_token_classification import * from sparknlp.annotator.classifier_dl.albert_for_zero_shot_classification import * from sparknlp.annotator.classifier_dl.camembert_for_zero_shot_classification import * - +from sparknlp.annotator.classifier_dl.bert_for_multiple_choice import * diff --git a/python/sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py b/python/sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py new file mode 100644 index 00000000000000..2c27f913e56fcc --- /dev/null +++ b/python/sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py @@ -0,0 +1,161 @@ +# Copyright 2017-2024 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sparknlp.common import * + +class BertForMultipleChoice(AnnotatorModel, + HasCaseSensitiveProperties, + HasBatchedAnnotate, + HasEngine, + HasMaxSentenceLengthLimit): + """BertForMultipleChoice can load BERT Models with a multiple choice classification head on top + (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks. + + Pretrained models can be loaded with :meth:`.pretrained` of the companion + object: + + >>> spanClassifier = BertForMultipleChoice.pretrained() \\ + ... .setInputCols(["document_question", "document_context"]) \\ + ... .setOutputCol("answer") + + The default model is ``"bert_base_uncased_multiple_choice"``, if no name is + provided. + + For available pretrained models please see the `Models Hub + `__. + + To see which models are compatible and how to import them see + `Import Transformers into Spark NLP 🚀 + `_. + + ====================== ====================== + Input Annotation types Output Annotation type + ====================== ====================== + ``DOCUMENT, DOCUMENT`` ``CHUNK`` + ====================== ====================== + + Parameters + ---------- + batchSize + Batch size. Large values allows faster processing but requires more + memory, by default 8 + caseSensitive + Whether to ignore case in tokens for embeddings matching, by default + False + maxSentenceLength + Max sentence length to process, by default 512 + + Examples + -------- + >>> import sparknlp + >>> from sparknlp.base import * + >>> from sparknlp.annotator import * + >>> from pyspark.ml import Pipeline + >>> documentAssembler = MultiDocumentAssembler() \\ + ... .setInputCols(["question", "context"]) \\ + ... .setOutputCols(["document_question", "document_context"]) + >>> questionAnswering = BertForMultipleChoice.pretrained() \\ + ... .setInputCols(["document_question", "document_context"]) \\ + ... .setOutputCol("answer") \\ + ... .setCaseSensitive(False) + >>> pipeline = Pipeline().setStages([ + ... documentAssembler, + ... questionAnswering + ... ]) + >>> data = spark.createDataFrame([["The Eiffel Tower is located in which country??", "Germany, France, Italy"]]).toDF("question", "context") + >>> result = pipeline.fit(data).transform(data) + >>> result.select("answer.result").show(truncate=False) + +--------------------+ + |result | + +--------------------+ + |[France] | + +--------------------+ + """ + name = "BertForMultipleChoice" + + inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.DOCUMENT] + + outputAnnotatorType = AnnotatorType.CHUNK + + choicesDelimiter = Param(Params._dummy(), + "choicesDelimiter", + "Delimiter character use to split the choices", + TypeConverters.toString) + + def setChoicesDelimiter(self, value): + """Sets delimiter character use to split the choices + + Parameters + ---------- + value : string + Delimiter character use to split the choices + """ + return self._set(caseSensitive=value) + + @keyword_only + def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.BertForMultipleChoice", + java_model=None): + super(BertForMultipleChoice, self).__init__( + classname=classname, + java_model=java_model + ) + self._setDefault( + batchSize=4, + maxSentenceLength=512, + caseSensitive=False, + choicesDelimiter = "," + ) + + @staticmethod + def loadSavedModel(folder, spark_session): + """Loads a locally saved model. + + Parameters + ---------- + folder : str + Folder of the saved model + spark_session : pyspark.sql.SparkSession + The current SparkSession + + Returns + ------- + BertForQuestionAnswering + The restored model + """ + from sparknlp.internal import _BertMultipleChoiceLoader + jModel = _BertMultipleChoiceLoader(folder, spark_session._jsparkSession)._java_obj + return BertForMultipleChoice(java_model=jModel) + + @staticmethod + def pretrained(name="bert_base_uncased_multiple_choice", lang="en", remote_loc=None): + """Downloads and loads a pretrained model. + + Parameters + ---------- + name : str, optional + Name of the pretrained model, by default + "bert_base_uncased_multiple_choice" + lang : str, optional + Language of the pretrained model, by default "en" + remote_loc : str, optional + Optional remote address of the resource, by default None. Will use + Spark NLPs repositories otherwise. + + Returns + ------- + BertForQuestionAnswering + The restored model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(BertForMultipleChoice, name, lang, remote_loc) \ No newline at end of file diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index c8732ef3ecb4e5..1ed209782bd18c 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -113,6 +113,13 @@ def __init__(self, path, jspark): jspark, ) +class _BertMultipleChoiceLoader(ExtendedJavaWrapper): + def __init__(self, path, jspark): + super(_BertMultipleChoiceLoader, self).__init__( + "com.johnsnowlabs.nlp.annotators.classifier.dl.BertForMultipleChoice.loadSavedModel", + path, + jspark, + ) class _DeBERTaLoader(ExtendedJavaWrapper): def __init__(self, path, jspark): diff --git a/python/test/annotator/classifier_dl/bert_for_multiple_choice_test.py b/python/test/annotator/classifier_dl/bert_for_multiple_choice_test.py new file mode 100644 index 00000000000000..369ecd44374b19 --- /dev/null +++ b/python/test/annotator/classifier_dl/bert_for_multiple_choice_test.py @@ -0,0 +1,76 @@ +# Copyright 2017-2024 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import pytest + +from sparknlp.annotator import * +from sparknlp.base import * +from test.util import SparkContextForTest + + +class BertForMultipleChoiceTestSetup(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + self.question = "The Eiffel Tower is located in which country?" + self.choices = "Germany, France, Italy" + + self.spark = SparkContextForTest.spark + empty_df = self.spark.createDataFrame([[""]]).toDF("text") + + document_assembler = MultiDocumentAssembler() \ + .setInputCols(["question", "context"]) \ + .setOutputCols(["document_question", "document_context"]) + + bert_for_multiple_choice = BertForMultipleChoice.pretrained() \ + .setInputCols(["document_question", "document_context"]) \ + .setOutputCol("answer") \ + + pipeline = Pipeline(stages=[document_assembler, bert_for_multiple_choice]) + + self.pipeline_model = pipeline.fit(empty_df) + + +@pytest.mark.slow +class BertForMultipleChoiceTest(BertForMultipleChoiceTestSetup, unittest.TestCase): + + def setUp(self): + super().setUp() + self.data = self.spark.createDataFrame([[self.question, self.choices]]).toDF("question","context") + self.data.show(truncate=False) + + def test_run(self): + result_df = self.pipeline_model.transform(self.data) + result_df.show(truncate=False) + for row in result_df.collect(): + self.assertTrue(row["answer"][0].result != "") + + +@pytest.mark.slow +class LightBertForMultipleChoiceTest(BertForMultipleChoiceTestSetup, unittest.TestCase): + + def setUp(self): + super().setUp() + + def runTest(self): + light_pipeline = LightPipeline(self.pipeline_model) + annotations_result = light_pipeline.fullAnnotate(self.question,self.choices) + print(annotations_result) + for result in annotations_result: + self.assertTrue(result["answer"][0].result != "") + + result = light_pipeline.annotate(self.question,self.choices) + print(result) + self.assertTrue(result["answer"] != "") diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala index e8ed6f51d2ff17..15f9345c3da88b 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala @@ -130,7 +130,7 @@ private[johnsnowlabs] class BertClassification( // we need the original form of the token // let's lowercase if needed right before the encoding - val basicTokenizer = new BasicTokenizer(caseSensitive = true, hasBeginEnd = false) + val basicTokenizer = new BasicTokenizer(caseSensitive = caseSensitive, hasBeginEnd = false) val encoder = new WordpieceEncoder(vocabulary) val sentences = docs.map { s => Sentence(s.result, s.begin, s.end, 0) } @@ -546,6 +546,15 @@ private[johnsnowlabs] class BertClassification( (startScores, endScores) } + override def tagSpanMultipleChoice(batch: Seq[Array[Int]]): Array[Float] = { + val logits = detectedEngine match { + case ONNX.name => computeLogitsMultipleChoiceWithOnnx(batch) + case Openvino.name => computeLogitsMultipleChoiceWithOv(batch) + } + + calculateSoftmax(logits) + } + private def computeLogitsWithTF( batch: Seq[Array[Int]], maxSentenceLength: Int): (Array[Float], Array[Float]) = { @@ -732,6 +741,87 @@ private[johnsnowlabs] class BertClassification( } } + private def computeLogitsMultipleChoiceWithOnnx(batch: Seq[Array[Int]]): Array[Float] = { + val sequenceLength = batch.head.length + val inputIds = Array(batch.map(x => x.map(_.toLong)).toArray) + val attentionMask = Array( + batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray) + val tokenTypeIds = Array(batch.map(_ => Array.fill(sequenceLength)(0L)).toArray) + + val (ortSession, ortEnv) = onnxWrapper.get.getSession(onnxSessionOptions) + val tokenTensors = OnnxTensor.createTensor(ortEnv, inputIds) + val maskTensors = OnnxTensor.createTensor(ortEnv, attentionMask) + val segmentTensors = OnnxTensor.createTensor(ortEnv, tokenTypeIds) + + val inputs = + Map( + "input_ids" -> tokenTensors, + "attention_mask" -> maskTensors, + "token_type_ids" -> segmentTensors).asJava + + try { + val output = ortSession.run(inputs) + try { + + val logits = output + .get("logits") + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + + logits + } finally if (output != null) output.close() + } catch { + case e: Exception => + // Log the exception as a warning + println("Exception in computeLogitsMultipleChoiceWithOnnx: ", e) + // Rethrow the exception to propagate it further + throw e + } + } + + private def computeLogitsMultipleChoiceWithOv(batch: Seq[Array[Int]]): Array[Float] = { + val (numChoices, sequenceLength) = (batch.length, batch.head.length) + // batch_size, num_choices, sequence_length + val shape = Some(Array(1, numChoices, sequenceLength)) + val (tokenTensors, maskTensors, segmentTensors) = + PrepareEmbeddings.prepareOvLongBatchTensorsWithSegment( + batch, + sequenceLength, + numChoices, + sentencePadTokenId, + shape) + + val compiledModel = openvinoWrapper.get.getCompiledModel() + val inferRequest = compiledModel.create_infer_request() + inferRequest.set_tensor("input_ids", tokenTensors) + inferRequest.set_tensor("attention_mask", maskTensors) + inferRequest.set_tensor("token_type_ids", segmentTensors) + + inferRequest.infer() + + try { + try { + val logits = inferRequest + .get_output_tensor() + .data() + + logits + } + } catch { + case e: Exception => + // Log the exception as a warning + logger.warn("Exception in computeLogitsMultipleChoiceWithOv", e) + // Rethrow the exception to propagate it further + throw e + } + } + def findIndexedToken( tokenizedSentences: Seq[TokenizedSentence], sentence: (WordpieceTokenizedSentence, Int), diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala index 919d6aa0d17c6e..af40658d46168d 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala @@ -304,6 +304,43 @@ private[johnsnowlabs] trait XXXForClassification { } + def predictSpanMultipleChoice( + documents: Seq[Annotation], + choicesDelimiter: String, + maxSentenceLength: Int, + caseSensitive: Boolean): Seq[Annotation] = { + + val questionAnnotation = Seq(documents.head) + val choices = + documents.drop(1).flatMap(annotation => annotation.result.split(choicesDelimiter)) + + val wordPieceTokenizedQuestions = + tokenizeDocument(questionAnnotation, maxSentenceLength, caseSensitive) + + val inputIds = choices.flatMap { choice => + val choiceAnnotation = + Seq(Annotation(AnnotatorType.DOCUMENT, 0, choice.length, choice, Map("sentence" -> "0"))) + val wordPieceTokenizedChoice = + tokenizeDocument(choiceAnnotation, maxSentenceLength, caseSensitive) + encodeSequenceWithPadding( + wordPieceTokenizedQuestions, + wordPieceTokenizedChoice, + maxSentenceLength) + } + + val scores = tagSpanMultipleChoice(inputIds) + val (score, scoreIndex) = scores.zipWithIndex.maxBy(_._1) + val prediction = choices(scoreIndex) + + Seq( + Annotation( + annotatorType = AnnotatorType.CHUNK, + begin = 0, + end = if (prediction.isEmpty) 0 else prediction.length - 1, + result = prediction, + metadata = Map("sentence" -> "0", "chunk" -> "0", "score" -> score.toString))) + } + def tokenizeWithAlignment( sentences: Seq[TokenizedSentence], maxSeqLength: Int, @@ -362,6 +399,38 @@ private[johnsnowlabs] trait XXXForClassification { Seq(Array(sentenceStartTokenId) ++ question ++ context) } + def encodeSequenceWithPadding( + seq1: Seq[WordpieceTokenizedSentence], + seq2: Seq[WordpieceTokenizedSentence], + maxSequenceLength: Int): Seq[Array[Int]] = { + + val question = seq1.flatMap { wpTokSentence => + wpTokSentence.tokens.map(t => t.pieceId) + }.toArray + + val context = seq2.flatMap { wpTokSentence => + wpTokSentence.tokens.map(t => t.pieceId) + }.toArray + + val availableLength = maxSequenceLength - 3 // (excluding special tokens) + val truncatedQuestion = question.take(availableLength) + val remainingLength = availableLength - truncatedQuestion.length + val truncatedContext = context.take(remainingLength) + + val assembleSequence = + Array(sentenceStartTokenId) ++ truncatedQuestion ++ Array(sentenceEndTokenId) ++ + truncatedContext ++ Array(sentenceEndTokenId) + + val paddingLength = maxSequenceLength - assembleSequence.length + val paddedSequence = if (paddingLength > 0) { + assembleSequence ++ Array.fill(paddingLength)(sentencePadTokenId) + } else { + assembleSequence + } + + Seq(paddedSequence) + } + def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] @@ -374,6 +443,8 @@ private[johnsnowlabs] trait XXXForClassification { def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) + def tagSpanMultipleChoice(batch: Seq[Array[Int]]): Array[Float] = Array() + /** Calculate softmax from returned logits * @param scores * logits output from output layer diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/util/PrepareEmbeddings.scala b/src/main/scala/com/johnsnowlabs/ml/ai/util/PrepareEmbeddings.scala index ddb85236678326..6529697fcfde74 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/util/PrepareEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/util/PrepareEmbeddings.scala @@ -82,18 +82,42 @@ private[johnsnowlabs] object PrepareEmbeddings { batch: Seq[Array[Int]], maxSentenceLength: Int, batchLength: Int, - sentencePadTokenId: Int = 0) + sentencePadTokenId: Int = 0, + shape: Option[Array[Int]] = None) : (org.intel.openvino.Tensor, org.intel.openvino.Tensor, org.intel.openvino.Tensor) = { - val shape = Array(batchLength, maxSentenceLength) - val tokenTensors = - new org.intel.openvino.Tensor(shape, batch.flatten.toArray) - val maskTensors = new org.intel.openvino.Tensor( - shape, - batch - .flatMap(sentence => sentence.map(x => if (x == sentencePadTokenId) 0 else 1)) - .toArray) + val tensorsShape = if (shape.isDefined) shape.get else Array(batchLength, maxSentenceLength) + val inputIds = batch.flatten.toArray + val attentionMask = batch + .flatMap(sentence => sentence.map(x => if (x == sentencePadTokenId) 0 else 1)) + .toArray + + val tokenTensors = new org.intel.openvino.Tensor(tensorsShape, inputIds) + val maskTensors = new org.intel.openvino.Tensor(tensorsShape, attentionMask) + + val segmentTensors = + new org.intel.openvino.Tensor(tensorsShape, Array.fill(batchLength * maxSentenceLength)(0)) + + (tokenTensors, maskTensors, segmentTensors) + } + + def prepareOvLongBatchTensorsWithSegment( + batch: Seq[Array[Int]], + maxSentenceLength: Int, + batchLength: Int, + sentencePadTokenId: Int = 0, + shape: Option[Array[Int]] = None) + : (org.intel.openvino.Tensor, org.intel.openvino.Tensor, org.intel.openvino.Tensor) = { + val tensorsShape = if (shape.isDefined) shape.get else Array(batchLength, maxSentenceLength) + val inputIds = batch.flatMap(x => x.map(xx => xx.toLong)).toArray + val attentionMask = batch + .flatMap(sentence => sentence.map(x => if (x == sentencePadTokenId) 0L else 1L)) + .toArray + + val tokenTensors = new org.intel.openvino.Tensor(tensorsShape, inputIds) + val maskTensors = new org.intel.openvino.Tensor(tensorsShape, attentionMask) + val segmentTensors = - new org.intel.openvino.Tensor(shape, Array.fill(batchLength * maxSentenceLength)(0)) + new org.intel.openvino.Tensor(tensorsShape, Array.fill(batchLength * maxSentenceLength)(0L)) (tokenTensors, maskTensors, segmentTensors) } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForMultipleChoice.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForMultipleChoice.scala new file mode 100644 index 00000000000000..eb2bd85580ed46 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForMultipleChoice.scala @@ -0,0 +1,334 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.annotators.classifier.dl + +import com.johnsnowlabs.ml.ai.BertClassification +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} +import com.johnsnowlabs.ml.openvino.{OpenvinoWrapper, ReadOpenvinoModel, WriteOpenvinoModel} +import com.johnsnowlabs.ml.tensorflow.TensorflowWrapper +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadTextAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.ml.util.{ONNX, Openvino} +import com.johnsnowlabs.nlp.serialization.MapFeature +import com.johnsnowlabs.nlp._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param.{IntParam, Param} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.SparkSession + +/** BertForMultipleChoice can load BERT Models with a multiple choice classification head on top + * (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks. + * + * Pretrained models can be loaded with `pretrained` of the companion object: + * {{{ + * val spanClassifier = BertForMultipleChoice.pretrained() + * .setInputCols(Array("document_question", "document_context")) + * .setOutputCol("answer") + * }}} + * The default model is `"bert_base_uncased_multiple_choice"`, if no name is provided. + * + * For available pretrained models please see the + * [[https://sparknlp.org/models?task=Multiple+Choice Models Hub]]. + * + * Models from the HuggingFace 🤗 Transformers library are also compatible with Spark NLP 🚀. To + * see which models are compatible and how to import them see + * [[https://github.com/JohnSnowLabs/spark-nlp/discussions/5669]] and to see more extended + * examples, see + * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForMultipleChoiceTestSpec.scala BertForMultipleChoiceTestSpec]]. + * + * ==Example== + * {{{ + * import spark.implicits._ + * import com.johnsnowlabs.nlp.base._ + * import com.johnsnowlabs.nlp.annotator._ + * import org.apache.spark.ml.Pipeline + * + * val document = new MultiDocumentAssembler() + * .setInputCols("question", "context") + * .setOutputCols("document_question", "document_context") + * + * val questionAnswering = BertForMultipleChoice.pretrained() + * .setInputCols(Array("document_question", "document_context")) + * .setOutputCol("answer") + * .setCaseSensitive(false) + * + * val pipeline = new Pipeline().setStages(Array( + * document, + * questionAnswering + * )) + * + * val data = Seq("The Eiffel Tower is located in which country?", "Germany, France, Italy").toDF("question", "context") + * val result = pipeline.fit(data).transform(data) + * + * result.select("answer.result").show(false) + * +---------------------+ + * |result | + * +---------------------+ + * |[France] | + * ++--------------------+ + * }}} + * + * @see + * [[BertForQuestionAnswering]] for Question Answering tasks + * @see + * [[https://sparknlp.org/docs/en/annotators Annotators Main Page]] for a list of transformer + * based classifiers + * @param uid + * required uid for storing annotator to disk + * @groupname anno Annotator types + * @groupdesc anno + * Required input and expected output annotator types + * @groupname Ungrouped Members + * @groupname param Parameters + * @groupname setParam Parameter setters + * @groupname getParam Parameter getters + * @groupname Ungrouped Members + * @groupprio param 1 + * @groupprio anno 2 + * @groupprio Ungrouped 3 + * @groupprio setParam 4 + * @groupprio getParam 5 + * @groupdesc param + * A list of (hyper-)parameter keys this annotator can take. Users can set and get the + * parameter values through setters and getters, respectively. + */ + +class BertForMultipleChoice(override val uid: String) + extends AnnotatorModel[BertForMultipleChoice] + with HasBatchedAnnotate[BertForMultipleChoice] + with WriteOnnxModel + with WriteOpenvinoModel + with HasCaseSensitiveProperties + with HasEngine { + + /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator + * type + */ + def this() = this(Identifiable.randomUID("BertForMultipleChoice")) + + /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator + * type + */ + override val inputAnnotatorTypes: Array[AnnotatorType] = + Array(AnnotatorType.DOCUMENT, AnnotatorType.DOCUMENT) + override val outputAnnotatorType: AnnotatorType = AnnotatorType.CHUNK + + /** Vocabulary used to encode the words to ids with WordPieceEncoder + * + * @group param + */ + val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected() + + /** @group setParam */ + def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value) + + /** @group setParam */ + def sentenceStartTokenId: Int = { + $$(vocabulary)("[CLS]") + } + + /** @group setParam */ + def sentenceEndTokenId: Int = { + $$(vocabulary)("[SEP]") + } + + /** Max sentence length to process (Default: `512`) + * + * @group param + */ + val maxSentenceLength = + new IntParam(this, "maxSentenceLength", "Max sentence length to process") + + /** @group setParam */ + def setMaxSentenceLength(value: Int): this.type = { + require( + value <= 512, + "BERT models do not support sequences longer than 512 because of trainable positional embeddings.") + require(value >= 1, "The maxSentenceLength must be at least 1") + set(maxSentenceLength, value) + this + } + + val choicesDelimiter = + new Param[String](this, "choicesDelimiter", "Delimiter character use to split the choices") + + def setChoicesDelimiter(value: String): this.type = set(choicesDelimiter, value) + + private var _model: Option[Broadcast[BertClassification]] = None + + /** @group setParam */ + def setModelIfNotSet( + spark: SparkSession, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper], + openvinoWrapper: Option[OpenvinoWrapper]): BertForMultipleChoice = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new BertClassification( + tensorflowWrapper, + onnxWrapper, + openvinoWrapper, + sentenceStartTokenId, + sentenceEndTokenId, + configProtoBytes = None, + tags = Map.empty[String, Int], + signatures = None, + vocabulary = $$(vocabulary)))) + } + + this + } + + /** @group getParam */ + def getModelIfNotSet: BertClassification = _model.get.value + + setDefault( + batchSize -> 4, + maxSentenceLength -> 512, + caseSensitive -> false, + choicesDelimiter -> ",") + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param batchedAnnotations + * Annotations in batches that correspond to inputAnnotationCols generated by previous + * annotators if any + * @return + * any number of annotations processed for every batch of input annotations. Not necessary + * one to one relationship + * + * IMPORTANT: !MUST! return sequences of equal lengths !! IMPORTANT: !MUST! return sentences + * that belong to the same original row !! (challenging) + */ + override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = { + batchedAnnotations.map(annotations => { + if (annotations.nonEmpty) { + getModelIfNotSet.predictSpanMultipleChoice( + annotations, + $(choicesDelimiter), + $(maxSentenceLength), + $(caseSensitive)) + } else { + Seq.empty[Annotation] + } + }) + } + + override def onWrite(path: String, spark: SparkSession): Unit = { + super.onWrite(path, spark) + getEngine match { + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + "_bert_multiple_choice_classification", + BertForMultipleChoice.onnxFile) + case Openvino.name => + writeOpenvinoModel( + path, + spark, + getModelIfNotSet.openvinoWrapper.get, + "openvino_model.xml", + BertForMultipleChoice.openvinoFile) + + } + } + +} + +trait ReadablePretrainedBertForMultipleChoiceModel + extends ParamsAndFeaturesReadable[BertForMultipleChoice] + with HasPretrained[BertForMultipleChoice] { + override val defaultModelName: Some[String] = Some("bert_base_uncased_multiple_choice") + + /** Java compliant-overrides */ + override def pretrained(): BertForMultipleChoice = super.pretrained() + + override def pretrained(name: String): BertForMultipleChoice = super.pretrained(name) + + override def pretrained(name: String, lang: String): BertForMultipleChoice = + super.pretrained(name, lang) + + override def pretrained(name: String, lang: String, remoteLoc: String): BertForMultipleChoice = + super.pretrained(name, lang, remoteLoc) +} + +trait ReadBertForMultipleChoiceModel extends ReadOnnxModel with ReadOpenvinoModel { + this: ParamsAndFeaturesReadable[BertForMultipleChoice] => + + override val onnxFile: String = "bert_mc_classification_onnx" + override val openvinoFile: String = "bert_mc_classification_openvino" + + def readModel(instance: BertForMultipleChoice, path: String, spark: SparkSession): Unit = { + instance.getEngine match { + case ONNX.name => + val onnxWrapper = + readOnnxModel(path, spark, "bert_mc_classification_onnx") + instance.setModelIfNotSet(spark, None, Some(onnxWrapper), None) + case Openvino.name => + val openvinoWrapper = readOpenvinoModel(path, spark, "bert_mc_classification_ov") + instance.setModelIfNotSet(spark, None, None, Some(openvinoWrapper)) + case _ => + throw new Exception(notSupportedEngineError) + } + } + + addReader(readModel) + + def loadSavedModel(modelPath: String, spark: SparkSession): BertForMultipleChoice = { + val (localModelPath, detectedEngine) = modelSanityCheck(modelPath) + val vocabs = loadTextAsset(localModelPath, "vocab.txt").zipWithIndex.toMap + val annotatorModel = new BertForMultipleChoice().setVocabulary(vocabs) + annotatorModel.set(annotatorModel.engine, detectedEngine) + + detectedEngine match { + case ONNX.name => + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper), None) + case Openvino.name => + val ovWrapper: OpenvinoWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine) + annotatorModel + .setModelIfNotSet(spark, None, None, Some(ovWrapper)) + case _ => + throw new Exception(notSupportedEngineError) + } + + annotatorModel + } + +} + +/** This is the companion object of [[BertForMultipleChoice]]. Please refer to that class for the + * documentation. + */ +object BertForMultipleChoice + extends ReadablePretrainedBertForMultipleChoiceModel + with ReadBertForMultipleChoiceModel diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala index d0ba5238deedaa..f271566e04715a 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala @@ -688,7 +688,9 @@ object PythonResourceDownloader { "AlbertForZeroShotClassification" -> AlbertForZeroShotClassification, "MxbaiEmbeddings" -> MxbaiEmbeddings, "SnowFlakeEmbeddings" -> SnowFlakeEmbeddings, - "CamemBertForZeroShotClassification" -> CamemBertForZeroShotClassification) + "CamemBertForZeroShotClassification" -> CamemBertForZeroShotClassification, + "BertForMultipleChoice" -> BertForMultipleChoice + ) // List pairs of types such as the one with key type can load a pretrained model from the value type val typeMapper: Map[String, String] = Map("ZeroShotNerModel" -> "RoBertaForQuestionAnswering") diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForMultipleChoiceTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForMultipleChoiceTestSpec.scala new file mode 100644 index 00000000000000..6aebffb53e8083 --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForMultipleChoiceTestSpec.scala @@ -0,0 +1,82 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.annotators.classifier.dl + +import com.johnsnowlabs.nlp.annotators.SparkSessionTest +import com.johnsnowlabs.nlp.base.LightPipeline +import com.johnsnowlabs.nlp.{Annotation, AssertAnnotations, MultiDocumentAssembler} +import com.johnsnowlabs.tags.SlowTest +import org.apache.spark.ml.Pipeline +import org.scalatest.flatspec.AnyFlatSpec + +class BertForMultipleChoiceTestSpec extends AnyFlatSpec with SparkSessionTest { + + import spark.implicits._ + + lazy val pipelineModel = getBertForMultipleChoicePipelineModel + + val testDataframe = Seq( + ("The Eiffel Tower is located in which country?", "Germany, France, Italy")) + .toDF("question", "context") + + "BertForMultipleChoiceTestSpec" should "answer a multiple choice question" taggedAs SlowTest in { + val resultDf = pipelineModel.transform(testDataframe) + resultDf.show(truncate=false) + + val result = AssertAnnotations.getActualResult(resultDf, "answer") + result.foreach { annotation => + annotation.foreach(a => assert(a.result.nonEmpty)) + } + } + + it should "work with light pipeline fullAnnotate" taggedAs SlowTest in { + val lightPipeline = new LightPipeline(pipelineModel) + val resultFullAnnotate = lightPipeline.fullAnnotate( + "The Eiffel Tower is located in which country?", + "Germany, France, Italy") + println(s"resultAnnotate: $resultFullAnnotate") + + val answerAnnotation = resultFullAnnotate("answer").head.asInstanceOf[Annotation] + + assert(answerAnnotation.result.nonEmpty) + } + + it should "work with light pipeline annotate" taggedAs SlowTest in { + val lightPipeline = new LightPipeline(pipelineModel) + val resultAnnotate = lightPipeline.annotate( + "The Eiffel Tower is located in which country?", + "Germany, France, Italy") + println(s"resultAnnotate: $resultAnnotate") + + assert(resultAnnotate("answer").head.nonEmpty) + } + + private def getBertForMultipleChoicePipelineModel = { + val documentAssembler = new MultiDocumentAssembler() + .setInputCols("question", "context") + .setOutputCols("document_question", "document_context") + + val bertForMultipleChoice = BertForMultipleChoice.pretrained() + .setInputCols("document_question", "document_context") + .setOutputCol("answer") + + val pipeline = new Pipeline().setStages(Array(documentAssembler, bertForMultipleChoice)) + + pipeline.fit(emptyDataSet) + } + +}