[SPARKNLP-1068] Introducing BLIPForQuestionAnswering transformer (#14422)

* [SPARKNLP-1068] Introducing BLIPForQuestionAnswering transformer
* [SPARKNLP-1068] Adding BLIPForQuestionAnswering import notebook example
* [SPARKNLP-1068] Fix fullAnnotateImage validation
* [SPARKNLP-1068] Solves BLIPForQuestionAnsweringTest issue
* [SPARKNLP-1068] Updates default BLIPForQuestionAnswering model name
* [SPARKNLP-1068] [skip test] Adding documentation to BLIPForQuestionAnswering
Showing 22 changed files with 4,734 additions and 61 deletions.
examples/python/transformers/HuggingFace_in_Spark_NLP_BLIPForQuestionAnswering.ipynb (3,425 additions, 0 deletions; large diff not rendered)
python/sparknlp/annotator/cv/blip_for_question_answering.py (172 additions, 0 deletions)

@@ -0,0 +1,172 @@
# Copyright 2017-2024 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sparknlp.common import *

class BLIPForQuestionAnswering(AnnotatorModel,
                               HasBatchedAnnotateImage,
                               HasImageFeatureProperties,
                               HasEngine,
                               HasCandidateLabelsProperties,
                               HasRescaleFactor):
    """BLIPForQuestionAnswering can load BLIP models for visual question
    answering. The model consists of a vision encoder, a text encoder, and a
    text decoder. The vision encoder encodes the input image, the text encoder
    encodes the input question together with the encoding of the image, and
    the text decoder outputs the answer to the question.

    Pretrained models can be loaded with :meth:`.pretrained` of the companion
    object:

    >>> visualQAClassifier = BLIPForQuestionAnswering.pretrained() \\
    ...     .setInputCols(["image_assembler"]) \\
    ...     .setOutputCol("answer")

    The default model is ``"blip_vqa_base"``, if no name is provided.

    For available pretrained models please see the `Models Hub
    <https://sparknlp.org/models?task=Question+Answering>`__.

    To see which models are compatible and how to import them see
    `Import Transformers into Spark NLP 🚀
    <https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``IMAGE``              ``DOCUMENT``
    ====================== ======================

    Parameters
    ----------
    batchSize
        Batch size. Larger values allow faster processing but require more
        memory, by default 2
    configProtoBytes
        ConfigProto from tensorflow, serialized into byte array.
    maxSentenceLength
        Max sentence length to process, by default 50

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline
    >>> from pyspark.sql.functions import lit
    >>> image_df = spark.read.format("image").load(path=images_path)
    >>> test_df = image_df.withColumn("text", lit("What's this picture about?"))
    >>> imageAssembler = ImageAssembler() \\
    ...     .setInputCol("image") \\
    ...     .setOutputCol("image_assembler")
    >>> visualQAClassifier = BLIPForQuestionAnswering.pretrained() \\
    ...     .setInputCols("image_assembler") \\
    ...     .setOutputCol("answer") \\
    ...     .setSize(384)
    >>> pipeline = Pipeline().setStages([
    ...     imageAssembler,
    ...     visualQAClassifier
    ... ])
    >>> result = pipeline.fit(test_df).transform(test_df)
    >>> result.select("image_assembler.origin", "answer.result").show(truncate=False)
    +--------------------------------------+------+
    |origin                                |result|
    +--------------------------------------+------+
    |[file:///content/images/cat_image.jpg]|[cats]|
    +--------------------------------------+------+
    """

    name = "BLIPForQuestionAnswering"

    inputAnnotatorTypes = [AnnotatorType.IMAGE]

    outputAnnotatorType = AnnotatorType.DOCUMENT

    configProtoBytes = Param(Params._dummy(),
                             "configProtoBytes",
                             "ConfigProto from tensorflow, serialized into byte array. Get with "
                             "config_proto.SerializeToString()",
                             TypeConverters.toListInt)

    maxSentenceLength = Param(Params._dummy(),
                              "maxSentenceLength",
                              "Maximum sentence length that the annotator will process. Above this, the sentence is skipped",
                              typeConverter=TypeConverters.toInt)

    def setMaxSentenceSize(self, value):
        """Sets the maximum sentence length that the annotator will process,
        by default 50.

        Parameters
        ----------
        value : int
            Maximum sentence length that the annotator will process
        """
        return self._set(maxSentenceLength=value)

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.cv.BLIPForQuestionAnswering",
                 java_model=None):
        super(BLIPForQuestionAnswering, self).__init__(
            classname=classname,
            java_model=java_model
        )
        self._setDefault(
            batchSize=2,
            size=384,
            maxSentenceLength=50
        )

    @staticmethod
    def loadSavedModel(folder, spark_session):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession

        Returns
        -------
        BLIPForQuestionAnswering
            The restored model
        """
        from sparknlp.internal import _BLIPForQuestionAnswering
        jModel = _BLIPForQuestionAnswering(folder, spark_session._jsparkSession)._java_obj
        return BLIPForQuestionAnswering(java_model=jModel)

    @staticmethod
    def pretrained(name="blip_vqa_base", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default "blip_vqa_base"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLP's repositories otherwise.

        Returns
        -------
        BLIPForQuestionAnswering
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(BLIPForQuestionAnswering, name, lang, remote_loc)
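
For readers coming from the import notebook above, here is a minimal sketch of the round trip that loadSavedModel enables. The export directory, save path, and session setup are illustrative assumptions, not part of this commit; only loadSavedModel, the column setters, and setSize come from the file itself.

import sparknlp
from sparknlp.annotator import BLIPForQuestionAnswering

spark = sparknlp.start()

# Hypothetical path to a TensorFlow export produced along the lines of the
# accompanying HuggingFace_in_Spark_NLP_BLIPForQuestionAnswering.ipynb notebook.
EXPORT_PATH = "./export/blip-vqa-base"

# Import the exported model and configure it like the pretrained default.
blip = BLIPForQuestionAnswering.loadSavedModel(EXPORT_PATH, spark) \
    .setInputCols(["image_assembler"]) \
    .setOutputCol("answer") \
    .setSize(384)

# Persist it as a Spark NLP model so later runs can skip the import step,
# then restore it with the standard Spark ML reader.
blip.write().overwrite().save("./blip_vqa_base_spark_nlp")
restored = BLIPForQuestionAnswering.load("./blip_vqa_base_spark_nlp")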
python/test/annotator/cv/blip_for_question_answering_test.py (80 additions, 0 deletions)

@@ -0,0 +1,80 @@
# Copyright 2017-2024 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
import os

from sparknlp.annotator import *
from sparknlp.base import *
from pyspark.ml import Pipeline
from pyspark.sql.functions import lit
from test.util import SparkSessionForTest

class BLIPForQuestionAnsweringTestSetup(unittest.TestCase):

    def setUp(self):
        self.images_path = os.getcwd() + "/../src/test/resources/image/"
        image_df = SparkSessionForTest.spark.read.format("image").load(
            path=self.images_path
        )

        self.test_df = image_df.withColumn("text", lit("What's this picture about?"))

        image_assembler = ImageAssembler().setInputCol("image").setOutputCol("image_assembler")

        imageClassifier = BLIPForQuestionAnswering.pretrained() \
            .setInputCols("image_assembler") \
            .setOutputCol("answer") \
            .setSize(384)

        self.pipeline = Pipeline(
            stages=[
                image_assembler,
                imageClassifier,
            ]
        )

        self.model = self.pipeline.fit(self.test_df)

@pytest.mark.slow
class BLIPForQuestionAnsweringTest(BLIPForQuestionAnsweringTestSetup, unittest.TestCase):

    def setUp(self):
        super().setUp()

    def runTest(self):
        result = self.model.transform(self.test_df).collect()

        for row in result:
            # row["answer"] is a list of annotations; assert it is non-empty
            # (comparing the list against "" would always pass).
            self.assertTrue(len(row["answer"]) > 0)


@pytest.mark.slow
class LightBLIPForQuestionAnsweringTest(BLIPForQuestionAnsweringTestSetup, unittest.TestCase):

    def setUp(self):
        super().setUp()

    def runTest(self):
        light_pipeline = LightPipeline(self.model)
        image_path = self.images_path + "bluetick.jpg"
        print("image_path: " + image_path)
        annotations_result = light_pipeline.fullAnnotateImage(
            image_path,
            "What's this picture about?"
        )

        for result in annotations_result:
            self.assertTrue(len(result["image_assembler"]) > 0)
            self.assertTrue(len(result["answer"]) > 0)
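
As a usage note on the LightPipeline path exercised above, here is a minimal sketch of batching several image/question pairs through fullAnnotateImage. It assumes the method accepts parallel lists of paths and questions (plausible given the validation fix in this commit, but an assumption here), and the second image name is illustrative.

# Not part of this commit's test suite: a hypothetical batched call.
light_pipeline = LightPipeline(self.model)
images = [self.images_path + "bluetick.jpg", self.images_path + "junco.jpg"]
questions = ["What's this picture about?", "What animal is this?"]

# Assumes fullAnnotateImage pairs images[i] with questions[i].
results = light_pipeline.fullAnnotateImage(images, questions)

for result in results:
    # Each entry mirrors the single-image case above: one list of
    # image_assembler annotations and one list of answer annotations.
    print(result["answer"][0].result)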