From b8757b8b8784fbd44c1bd96d3af6ce20a9815c5e Mon Sep 17 00:00:00 2001 From: "D. Ferruzzi" Date: Wed, 15 May 2024 11:34:32 -0700 Subject: [PATCH] Amazon Bedrock - Retrieve and RetrieveAndGenerate (#39500) Both of these calls are super fast and neither has any kind of waiter or means of checking the status, so here can not be any sensor or trigger for them. They are simple client calls, but I think making these Operators allowed us to simplify the complicated formatting on the client API call itself, for a better UX. --- airflow/providers/amazon/aws/hooks/bedrock.py | 20 ++ .../providers/amazon/aws/operators/bedrock.py | 220 +++++++++++++++++- .../providers/amazon/aws/utils/__init__.py | 7 + .../operators/bedrock.rst | 52 +++++ docs/spelling_wordlist.txt | 1 + .../amazon/aws/hooks/test_bedrock.py | 8 +- .../amazon/aws/operators/test_bedrock.py | 174 ++++++++++++++ .../aws/example_bedrock_knowledge_base.py | 70 +++++- 8 files changed, 543 insertions(+), 9 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/bedrock.py b/airflow/providers/amazon/aws/hooks/bedrock.py index b7c43be67ed85..0c2fd1a11cfc2 100644 --- a/airflow/providers/amazon/aws/hooks/bedrock.py +++ b/airflow/providers/amazon/aws/hooks/bedrock.py @@ -77,3 +77,23 @@ class BedrockAgentHook(AwsBaseHook): def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = self.client_type super().__init__(*args, **kwargs) + + +class BedrockAgentRuntimeHook(AwsBaseHook): + """ + Interact with the Amazon Agents for Bedrock API. + + Provide thin wrapper around :external+boto3:py:class:`boto3.client("bedrock-agent-runtime") `. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + client_type = "bedrock-agent-runtime" + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = self.client_type + super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py index 807adaf3938cd..3efd4ff8a5a25 100644 --- a/airflow/providers/amazon/aws/operators/bedrock.py +++ b/airflow/providers/amazon/aws/operators/bedrock.py @@ -20,11 +20,17 @@ from time import sleep from typing import TYPE_CHECKING, Any, Sequence +import botocore from botocore.exceptions import ClientError from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook, BedrockRuntimeHook +from airflow.providers.amazon.aws.hooks.bedrock import ( + BedrockAgentHook, + BedrockAgentRuntimeHook, + BedrockHook, + BedrockRuntimeHook, +) from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator from airflow.providers.amazon.aws.triggers.bedrock import ( BedrockCustomizeModelCompletedTrigger, @@ -32,7 +38,7 @@ BedrockKnowledgeBaseActiveTrigger, BedrockProvisionModelThroughputCompletedTrigger, ) -from airflow.providers.amazon.aws.utils import validate_execute_complete_event +from airflow.providers.amazon.aws.utils import get_botocore_version, validate_execute_complete_event from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.utils.helpers import prune_dict from airflow.utils.timezone import utcnow @@ -664,3 +670,213 @@ def execute(self, context: Context) -> str: ) return ingestion_job_id + + +class BedrockRaGOperator(AwsBaseOperator[BedrockAgentRuntimeHook]): + """ + Query a knowledge base and generate responses based on the retrieved results with sources citations. + + NOTE: Support for EXTERNAL SOURCES was added in botocore 1.34.90 + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BedrockRaGOperator` + + :param input: The query to be made to the knowledge base. (templated) + :param source_type: The type of resource that is queried by the request. (templated) + Must be one of 'KNOWLEDGE_BASE' or 'EXTERNAL_SOURCES', and the appropriate config values must also be provided. + If set to 'KNOWLEDGE_BASE' then `knowledge_base_id` must be provided, and `vector_search_config` may be. + If set to `EXTERNAL_SOURCES` then `sources` must also be provided. + NOTE: Support for EXTERNAL SOURCES was added in botocore 1.34.90 + :param model_arn: The ARN of the foundation model used to generate a response. (templated) + :param prompt_template: The template for the prompt that's sent to the model for response generation. + You can include prompt placeholders, which are replaced before the prompt is sent to the model + to provide instructions and context to the model. In addition, you can include XML tags to delineate + meaningful sections of the prompt template. (templated) + :param knowledge_base_id: The unique identifier of the knowledge base that is queried. (templated) + Can only be specified if source_type='KNOWLEDGE_BASE'. + :param vector_search_config: How the results from the vector search should be returned. (templated) + Can only be specified if source_type='KNOWLEDGE_BASE'. + For more information, see https://docs.aws.amazon.com/bedrock/latest/userguide/kb-test-config.html. + :param sources: The documents used as reference for the response. (templated) + Can only be specified if source_type='EXTERNAL_SOURCES' + NOTE: Support for EXTERNAL SOURCES was added in botocore 1.34.90 + :param rag_kwargs: Additional keyword arguments to pass to the API call. (templated) + """ + + aws_hook_class = BedrockAgentRuntimeHook + template_fields: Sequence[str] = aws_template_fields( + "input", + "source_type", + "model_arn", + "prompt_template", + "knowledge_base_id", + "vector_search_config", + "sources", + "rag_kwargs", + ) + + def __init__( + self, + input: str, + source_type: str, + model_arn: str, + prompt_template: str | None = None, + knowledge_base_id: str | None = None, + vector_search_config: dict[str, Any] | None = None, + sources: list[dict[str, Any]] | None = None, + rag_kwargs: dict[str, Any] | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.input = input + self.prompt_template = prompt_template + self.source_type = source_type.upper() + self.knowledge_base_id = knowledge_base_id + self.model_arn = model_arn + self.vector_search_config = vector_search_config + self.sources = sources + self.rag_kwargs = rag_kwargs or {} + + def validate_inputs(self): + if self.source_type == "KNOWLEDGE_BASE": + if self.knowledge_base_id is None: + raise AttributeError( + "If `source_type` is set to 'KNOWLEDGE_BASE' then `knowledge_base_id` must be provided." + ) + if self.sources is not None: + raise AttributeError( + "`sources` can not be used when `source_type` is set to 'KNOWLEDGE_BASE'." + ) + elif self.source_type == "EXTERNAL_SOURCES": + if not self.sources is not None: + raise AttributeError( + "If `source_type` is set to `EXTERNAL_SOURCES` then `sources` must also be provided." + ) + if self.vector_search_config or self.knowledge_base_id: + raise AttributeError( + "`vector_search_config` and `knowledge_base_id` can not be used " + "when `source_type` is set to `EXTERNAL_SOURCES`" + ) + else: + raise AttributeError( + "`source_type` must be one of 'KNOWLEDGE_BASE' or 'EXTERNAL_SOURCES', " + "and the appropriate config values must also be provided." + ) + + def build_rag_config(self) -> dict[str, Any]: + result: dict[str, Any] = {} + base_config: dict[str, Any] = { + "modelArn": self.model_arn, + } + + if self.prompt_template: + base_config["generationConfiguration"] = { + "promptTemplate": {"textPromptTemplate": self.prompt_template} + } + + if self.source_type == "KNOWLEDGE_BASE": + if self.vector_search_config: + base_config["retrievalConfiguration"] = { + "vectorSearchConfiguration": self.vector_search_config + } + + result = { + "type": self.source_type, + "knowledgeBaseConfiguration": { + **base_config, + "knowledgeBaseId": self.knowledge_base_id, + }, + } + + if self.source_type == "EXTERNAL_SOURCES": + result = { + "type": self.source_type, + "externalSourcesConfiguration": {**base_config, "sources": self.sources}, + } + return result + + def execute(self, context: Context) -> Any: + self.validate_inputs() + + try: + result = self.hook.conn.retrieve_and_generate( + input={"text": self.input}, + retrieveAndGenerateConfiguration=self.build_rag_config(), + **self.rag_kwargs, + ) + except botocore.exceptions.ParamValidationError as error: + if ( + 'Unknown parameter in retrieveAndGenerateConfiguration: "externalSourcesConfiguration"' + in str(error) + ) and (self.source_type == "EXTERNAL_SOURCES"): + self.log.error( + "You are attempting to use External Sources and the BOTO API returned an " + "error message which may indicate the need to update botocore to do this. \n" + "Support for External Sources was added in botocore 1.34.90 and you are using botocore %s", + ".".join(map(str, get_botocore_version())), + ) + raise + + self.log.info( + "\nPrompt: %s\nResponse: %s\nCitations: %s", + self.input, + result["output"]["text"], + result["citations"], + ) + return result + + +class BedrockRetrieveOperator(AwsBaseOperator[BedrockAgentRuntimeHook]): + """ + Query a knowledge base and retrieve results with source citations. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BedrockRetrieveOperator` + + :param retrieval_query: The query to be made to the knowledge base. (templated) + :param knowledge_base_id: The unique identifier of the knowledge base that is queried. (templated) + :param vector_search_config: How the results from the vector search should be returned. (templated) + For more information, see https://docs.aws.amazon.com/bedrock/latest/userguide/kb-test-config.html. + :param retrieve_kwargs: Additional keyword arguments to pass to the API call. (templated) + """ + + aws_hook_class = BedrockAgentRuntimeHook + template_fields: Sequence[str] = aws_template_fields( + "retrieval_query", + "knowledge_base_id", + "vector_search_config", + "retrieve_kwargs", + ) + + def __init__( + self, + retrieval_query: str, + knowledge_base_id: str, + vector_search_config: dict[str, Any] | None = None, + retrieve_kwargs: dict[str, Any] | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.retrieval_query = retrieval_query + self.knowledge_base_id = knowledge_base_id + self.vector_search_config = vector_search_config + self.retrieve_kwargs = retrieve_kwargs or {} + + def execute(self, context: Context) -> Any: + retrieval_configuration = ( + {"retrievalConfiguration": {"vectorSearchConfiguration": self.vector_search_config}} + if self.vector_search_config + else {} + ) + + result = self.hook.conn.retrieve( + retrievalQuery={"text": self.retrieval_query}, + knowledgeBaseId=self.knowledge_base_id, + **retrieval_configuration, + **self.retrieve_kwargs, + ) + + self.log.info("\nQuery: %s\nRetrieved: %s", self.retrieval_query, result["retrievalResults"]) + return result diff --git a/airflow/providers/amazon/aws/utils/__init__.py b/airflow/providers/amazon/aws/utils/__init__.py index 2a96c3e4478fe..218ccc5768a17 100644 --- a/airflow/providers/amazon/aws/utils/__init__.py +++ b/airflow/providers/amazon/aws/utils/__init__.py @@ -22,6 +22,8 @@ from enum import Enum from typing import Any +import importlib_metadata + from airflow.exceptions import AirflowException from airflow.utils.helpers import prune_dict from airflow.version import version @@ -74,6 +76,11 @@ def get_airflow_version() -> tuple[int, ...]: return tuple(int(x) for x in match.groups()) +def get_botocore_version() -> tuple[int, ...]: + """Return the version number of the installed botocore package in the form of a tuple[int,...].""" + return tuple(map(int, importlib_metadata.version("botocore").split(".")[:3])) + + def validate_execute_complete_event(event: dict[str, Any] | None = None) -> dict[str, Any]: if event is None: err_msg = "Trigger error: event is None" diff --git a/docs/apache-airflow-providers-amazon/operators/bedrock.rst b/docs/apache-airflow-providers-amazon/operators/bedrock.rst index d5074c319d5de..1808f5138c320 100644 --- a/docs/apache-airflow-providers-amazon/operators/bedrock.rst +++ b/docs/apache-airflow-providers-amazon/operators/bedrock.rst @@ -116,6 +116,9 @@ Create an Amazon Bedrock Knowledge Base To create an Amazon Bedrock Knowledge Base, you can use :class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockCreateKnowledgeBaseOperator`. +For more information on which models support embedding data into a vector store, see +https://docs.aws.amazon.com/bedrock/latest/userguide/knowledge-base-supported.html + .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py :language: python :dedent: 4 @@ -174,6 +177,55 @@ To add data from an Amazon S3 bucket into an Amazon Bedrock Data Source, you can :start-after: [START howto_operator_bedrock_ingest_data] :end-before: [END howto_operator_bedrock_ingest_data] +.. _howto/operator:BedrockRetrieveOperator: + +Amazon Bedrock Retrieve +======================= + +To query a knowledge base, you can use :class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockRetrieveOperator`. + +The response will only contain citations to sources that are relevant to the query. If you +would like to pass the results through an LLM in order to generate a text response, see +:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockRaGOperator` + +For more information on which models support retrieving information from a knowledge base, see +https://docs.aws.amazon.com/bedrock/latest/userguide/knowledge-base-supported.html + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_bedrock_retrieve] + :end-before: [END howto_operator_bedrock_retrieve] + +.. _howto/operator:BedrockRaGOperator: + +Amazon Bedrock Retrieve and Generate (RaG) +========================================== + +To query a knowledge base or external sources and generate a text response based on the retrieved +results, you can use :class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockRaGOperator`. + +The response will contain citations to sources that are relevant to the query as well as a generated text reply. +For more information on which models support retrieving information from a knowledge base, see +https://docs.aws.amazon.com/bedrock/latest/userguide/knowledge-base-supported.html + +NOTE: Support for "external sources" was added in boto 1.34.90 + +Example using an Amazon Bedrock Knowledge Base: + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_bedrock_knowledge_base_rag] + :end-before: [END howto_operator_bedrock_knowledge_base_rag] + +Example using a PDF file in an Amazon S3 Bucket: + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_bedrock_external_sources_rag] + :end-before: [END howto_operator_bedrock_external_sources_rag] Sensors diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 40b744fc5d1d6..1c35408c09d50 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1305,6 +1305,7 @@ queueing quickstart quotechar rabbitmq +RaG RBAC rbac rc diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py b/tests/providers/amazon/aws/hooks/test_bedrock.py index c47b46e9a8647..da8a8d14a13d4 100644 --- a/tests/providers/amazon/aws/hooks/test_bedrock.py +++ b/tests/providers/amazon/aws/hooks/test_bedrock.py @@ -18,7 +18,12 @@ import pytest -from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook, BedrockRuntimeHook +from airflow.providers.amazon.aws.hooks.bedrock import ( + BedrockAgentHook, + BedrockAgentRuntimeHook, + BedrockHook, + BedrockRuntimeHook, +) class TestBedrockHooks: @@ -28,6 +33,7 @@ class TestBedrockHooks: pytest.param(BedrockHook(), "bedrock", id="bedrock"), pytest.param(BedrockRuntimeHook(), "bedrock-runtime", id="bedrock-runtime"), pytest.param(BedrockAgentHook(), "bedrock-agent", id="bedrock-agent"), + pytest.param(BedrockAgentRuntimeHook(), "bedrock-agent-runtime", id="bedrock-agent-runtime"), ], ) def test_bedrock_hooks(self, test_hook, service_name): diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py b/tests/providers/amazon/aws/operators/test_bedrock.py index ae25588146797..b49d09b52a5c0 100644 --- a/tests/providers/amazon/aws/operators/test_bedrock.py +++ b/tests/providers/amazon/aws/operators/test_bedrock.py @@ -33,6 +33,7 @@ BedrockCustomizeModelOperator, BedrockIngestDataOperator, BedrockInvokeModelOperator, + BedrockRaGOperator, ) if TYPE_CHECKING: @@ -346,3 +347,176 @@ def test_id_returned(self, mock_conn): result = self.operator.execute({}) assert result == self.INGESTION_JOB_ID + + +class TestBedrockRaGOperator: + VECTOR_SEARCH_CONFIG = {"filter": {"equals": {"key": "some key", "value": "some value"}}} + KNOWLEDGE_BASE_ID = "knowledge_base_id" + SOURCES = [{"sourceType": "S3", "s3Location": "bucket"}] + MODEL_ARN = "model arn" + + @pytest.mark.parametrize( + "source_type, vector_search_config, knowledge_base_id, sources, expect_success", + [ + pytest.param( + "invalid_source_type", + None, + None, + None, + False, + id="invalid_source_type", + ), + pytest.param( + "KNOWLEDGE_BASE", + VECTOR_SEARCH_CONFIG, + None, + None, + False, + id="KNOWLEDGE_BASE_without_knowledge_base_id_fails", + ), + pytest.param( + "KNOWLEDGE_BASE", + None, + KNOWLEDGE_BASE_ID, + None, + True, + id="KNOWLEDGE_BASE_passes", + ), + pytest.param( + "KNOWLEDGE_BASE", + VECTOR_SEARCH_CONFIG, + KNOWLEDGE_BASE_ID, + SOURCES, + False, + id="KNOWLEDGE_BASE_with_sources_fails", + ), + pytest.param( + "KNOWLEDGE_BASE", + VECTOR_SEARCH_CONFIG, + KNOWLEDGE_BASE_ID, + None, + True, + id="KNOWLEDGE_BASE_with_vector_config_passes", + ), + pytest.param( + "EXTERNAL_SOURCES", + VECTOR_SEARCH_CONFIG, + None, + SOURCES, + False, + id="EXTERNAL_SOURCES_with_search_config_fails", + ), + pytest.param( + "EXTERNAL_SOURCES", + None, + KNOWLEDGE_BASE_ID, + SOURCES, + False, + id="EXTERNAL_SOURCES_with_knohwledge_base_id_fails", + ), + pytest.param( + "EXTERNAL_SOURCES", + None, + None, + SOURCES, + True, + id="EXTERNAL_SOURCES_with_sources_passes", + ), + ], + ) + def test_input_validation( + self, source_type, vector_search_config, knowledge_base_id, sources, expect_success + ): + op = BedrockRaGOperator( + task_id="test_rag", + input="some text prompt", + source_type=source_type, + model_arn=self.MODEL_ARN, + knowledge_base_id=knowledge_base_id, + vector_search_config=vector_search_config, + sources=sources, + ) + + if expect_success: + op.validate_inputs() + else: + with pytest.raises(AttributeError): + op.validate_inputs() + + @pytest.mark.parametrize( + "prompt_template", + [ + pytest.param(None, id="no_prompt_template"), + pytest.param("valid template", id="prompt_template_provided"), + ], + ) + def test_knowledge_base_build_rag_config(self, prompt_template): + expected_source_type = "KNOWLEDGE_BASE" + op = BedrockRaGOperator( + task_id="test_rag", + input="some text prompt", + source_type=expected_source_type, + model_arn=self.MODEL_ARN, + knowledge_base_id=self.KNOWLEDGE_BASE_ID, + vector_search_config=self.VECTOR_SEARCH_CONFIG, + prompt_template=prompt_template, + ) + expected_config_without_template = { + "knowledgeBaseId": self.KNOWLEDGE_BASE_ID, + "modelArn": self.MODEL_ARN, + "retrievalConfiguration": {"vectorSearchConfiguration": self.VECTOR_SEARCH_CONFIG}, + } + expected_config_template = { + "generationConfiguration": {"promptTemplate": {"textPromptTemplate": prompt_template}} + } + config = op.build_rag_config() + + assert len(config.keys()) == 2 + assert config.get("knowledgeBaseConfiguration", False) + assert config["type"] == expected_source_type + + if not prompt_template: + assert config["knowledgeBaseConfiguration"] == expected_config_without_template + else: + assert config["knowledgeBaseConfiguration"] == { + **expected_config_without_template, + **expected_config_template, + } + + @pytest.mark.parametrize( + "prompt_template", + [ + pytest.param(None, id="no_prompt_template"), + pytest.param("valid template", id="prompt_template_provided"), + ], + ) + def test_external_sources_build_rag_config(self, prompt_template): + expected_source_type = "EXTERNAL_SOURCES" + op = BedrockRaGOperator( + task_id="test_rag", + input="some text prompt", + source_type=expected_source_type, + model_arn=self.MODEL_ARN, + sources=self.SOURCES, + prompt_template=prompt_template, + ) + expected_config_without_template = { + "modelArn": self.MODEL_ARN, + "sources": self.SOURCES, + } + expected_config_template = { + "generationConfiguration": {"promptTemplate": {"textPromptTemplate": prompt_template}} + } + config = op.build_rag_config() + + assert len(config.keys()) == 2 + assert config.get("externalSourcesConfiguration", False) + assert config["type"] == expected_source_type + + if not prompt_template: + assert config["externalSourcesConfiguration"] == expected_config_without_template + else: + assert config["externalSourcesConfiguration"] == { + **expected_config_without_template, + **expected_config_template, + } diff --git a/tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py b/tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py index 959d4ba2c69fa..b0a333382a44f 100644 --- a/tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py +++ b/tests/system/providers/amazon/aws/example_bedrock_knowledge_base.py @@ -34,7 +34,8 @@ ) from airflow import DAG -from airflow.decorators import task +from airflow.decorators import task, task_group +from airflow.operators.empty import EmptyOperator from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -43,6 +44,8 @@ BedrockCreateDataSourceOperator, BedrockCreateKnowledgeBaseOperator, BedrockIngestDataOperator, + BedrockRaGOperator, + BedrockRetrieveOperator, ) from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator from airflow.providers.amazon.aws.sensors.bedrock import ( @@ -52,15 +55,18 @@ from airflow.providers.amazon.aws.sensors.opensearch_serverless import ( OpenSearchServerlessCollectionActiveSensor, ) +from airflow.providers.amazon.aws.utils import get_botocore_version +from airflow.utils.edgemodifier import Label from airflow.utils.helpers import chain from airflow.utils.trigger_rule import TriggerRule from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder -############################################################################### -# NOTE: The account running this test must first manually request access to -# the `Titan Embeddings G1 - Text` foundation model via the Bedrock console. -# Gaining access to the model can take 24 hours from the time of request. -############################################################################### +########################################################################################################### +# NOTE: +# The account running this test must first manually request access to the `Titan Embeddings G1 - Text` +# and `Anthropic Claude v2.0` foundation models via the Bedrock console. Gaining access to the models +# can take 24 hours from the time of request. +########################################################################################################### # Externally fetched variables: ROLE_ARN_KEY = "ROLE_ARN" @@ -71,6 +77,37 @@ log = logging.getLogger(__name__) +@task_group +def external_sources_rag_group(): + """External Sources were added in boto 1.34.90, skip this operator if the version is below that.""" + + # [START howto_operator_bedrock_external_sources_rag] + external_sources_rag = BedrockRaGOperator( + task_id="external_sources_rag", + input="Who was the CEO of Amazon in 2022?", + source_type="EXTERNAL_SOURCES", + model_arn=f"arn:aws:bedrock:{region_name}::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0", + sources=[ + { + "sourceType": "S3", + "s3Location": {"uri": f"s3://{bucket_name}/AMZN-2022-Shareholder-Letter.pdf"}, + } + ], + ) + # [END howto_operator_bedrock_external_sources_rag] + + @task.branch + def run_or_skip(): + log.info("Found botocore version %s.", botocore_version := get_botocore_version()) + return end_workflow.task_id if botocore_version < (1, 34, 90) else external_sources_rag.task_id + + run_or_skip = run_or_skip() + end_workflow = EmptyOperator(task_id="end_workflow", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS) + + chain(run_or_skip, Label("Boto version does not support External Sources"), end_workflow) + chain(run_or_skip, external_sources_rag, end_workflow) + + @task def create_opensearch_policies(bedrock_role_arn: str, collection_name: str, policy_name_suffix: str) -> None: """ @@ -480,6 +517,24 @@ def delete_opensearch_policies(collection_name: str): ) # [END howto_sensor_bedrock_ingest_data] + # [START howto_operator_bedrock_knowledge_base_rag] + knowledge_base_rag = BedrockRaGOperator( + task_id="knowledge_base_rag", + input="Who was the CEO of Amazon on 2022?", + source_type="KNOWLEDGE_BASE", + model_arn=f"arn:aws:bedrock:{region_name}::foundation-model/anthropic.claude-v2", + knowledge_base_id=create_knowledge_base.output, + ) + # [END howto_operator_bedrock_knowledge_base_rag] + + # [START howto_operator_bedrock_retrieve] + retrieve = BedrockRetrieveOperator( + task_id="retrieve", + knowledge_base_id=create_knowledge_base.output, + retrieval_query="Who was the CEO of Amazon in 1997?", + ) + # [END howto_operator_bedrock_retrieve] + delete_bucket = S3DeleteBucketOperator( task_id="delete_bucket", trigger_rule=TriggerRule.ALL_DONE, @@ -502,6 +557,9 @@ def delete_opensearch_policies(collection_name: str): create_data_source, ingest_data, await_ingest, + knowledge_base_rag, + external_sources_rag_group(), + retrieve, delete_data_source( knowledge_base_id=create_knowledge_base.output, data_source_id=create_data_source.output,