diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index b011d9512952..84b61d4cbd1b 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -53,7 +53,9 @@ ChatCompletionToolCallFunctionChunk, ChatCompletionDeltaChunk, ) +from litellm.caching import DualCache +iam_cache = DualCache() class AmazonCohereChatConfig: """ @@ -325,38 +327,53 @@ def get_credentials( ) = params_to_check ### CHECK STS ### - if ( - aws_web_identity_token is not None - and aws_role_name is not None - and aws_session_name is not None - ): - oidc_token = get_secret(aws_web_identity_token) + if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None: + iam_creds_cache_key = json.dumps({ + "aws_web_identity_token": aws_web_identity_token, + "aws_role_name": aws_role_name, + "aws_session_name": aws_session_name, + "aws_region_name": aws_region_name, + }) + + iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key) + if iam_creds_dict is None: + oidc_token = get_secret(aws_web_identity_token) + + if oidc_token is None: + raise BedrockError( + message="OIDC token could not be retrieved from secret manager.", + status_code=401, + ) - if oidc_token is None: - raise BedrockError( - message="OIDC token could not be retrieved from secret manager.", - status_code=401, + sts_client = boto3.client( + "sts", + region_name=aws_region_name, + endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com" ) - sts_client = boto3.client("sts") + # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html + sts_response = sts_client.assume_role_with_web_identity( + RoleArn=aws_role_name, + RoleSessionName=aws_session_name, + WebIdentityToken=oidc_token, + DurationSeconds=3600, + ) - # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html - # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html - sts_response = sts_client.assume_role_with_web_identity( - RoleArn=aws_role_name, - RoleSessionName=aws_session_name, - WebIdentityToken=oidc_token, - DurationSeconds=3600, - ) + iam_creds_dict = { + "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"], + "aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"], + "aws_session_token": sts_response["Credentials"]["SessionToken"], + "region_name": aws_region_name, + } - session = boto3.Session( - aws_access_key_id=sts_response["Credentials"]["AccessKeyId"], - aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"], - aws_session_token=sts_response["Credentials"]["SessionToken"], - region_name=aws_region_name, - ) + iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60) - return session.get_credentials() + session = boto3.Session(**iam_creds_dict) + + iam_creds = session.get_credentials() + + return iam_creds elif aws_role_name is not None and aws_session_name is not None: sts_client = boto3.client( "sts", @@ -1416,38 +1433,53 @@ def get_credentials( ) = params_to_check ### CHECK STS ### - if ( - aws_web_identity_token is not None - and aws_role_name is not None - and aws_session_name is not None - ): - oidc_token = get_secret(aws_web_identity_token) + if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None: + iam_creds_cache_key = json.dumps({ + "aws_web_identity_token": aws_web_identity_token, + "aws_role_name": aws_role_name, + "aws_session_name": aws_session_name, + "aws_region_name": aws_region_name, + }) + + iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key) + if iam_creds_dict is None: + oidc_token = get_secret(aws_web_identity_token) + + if oidc_token is None: + raise BedrockError( + message="OIDC token could not be retrieved from secret manager.", + status_code=401, + ) - if oidc_token is None: - raise BedrockError( - message="OIDC token could not be retrieved from secret manager.", - status_code=401, + sts_client = boto3.client( + "sts", + region_name=aws_region_name, + endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com" ) - sts_client = boto3.client("sts") + # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html + sts_response = sts_client.assume_role_with_web_identity( + RoleArn=aws_role_name, + RoleSessionName=aws_session_name, + WebIdentityToken=oidc_token, + DurationSeconds=3600, + ) - # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html - # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html - sts_response = sts_client.assume_role_with_web_identity( - RoleArn=aws_role_name, - RoleSessionName=aws_session_name, - WebIdentityToken=oidc_token, - DurationSeconds=3600, - ) + iam_creds_dict = { + "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"], + "aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"], + "aws_session_token": sts_response["Credentials"]["SessionToken"], + "region_name": aws_region_name, + } - session = boto3.Session( - aws_access_key_id=sts_response["Credentials"]["AccessKeyId"], - aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"], - aws_session_token=sts_response["Credentials"]["SessionToken"], - region_name=aws_region_name, - ) + iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60) - return session.get_credentials() + session = boto3.Session(**iam_creds_dict) + + iam_creds = session.get_credentials() + + return iam_creds elif aws_role_name is not None and aws_session_name is not None: sts_client = boto3.client( "sts", diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index 64e7741e2a19..b953ca2a3a63 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -220,13 +220,13 @@ def test_completion_bedrock_claude_sts_oidc_auth(): aws_web_identity_token = "oidc/circleci_v2/" aws_region_name = os.environ["AWS_REGION_NAME"] # aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"] - # TODO: This is using David's IAM role, we should use Litellm's IAM role eventually + # TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci" try: litellm.set_verbose = True - response = completion( + response_1 = completion( model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", messages=messages, max_tokens=10, @@ -236,8 +236,40 @@ def test_completion_bedrock_claude_sts_oidc_auth(): aws_role_name=aws_role_name, aws_session_name="my-test-session", ) - # Add any assertions here to check the response - print(response) + print(response_1) + assert len(response_1.choices) > 0 + assert len(response_1.choices[0].message.content) > 0 + + # This second call is to verify that the cache isn't breaking anything + response_2 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=5, + temperature=0.2, + aws_region_name=aws_region_name, + aws_web_identity_token=aws_web_identity_token, + aws_role_name=aws_role_name, + aws_session_name="my-test-session", + ) + print(response_2) + assert len(response_2.choices) > 0 + assert len(response_2.choices[0].message.content) > 0 + + # This third call is to verify that the cache isn't used for a different region + response_3 = completion( + model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", + messages=messages, + max_tokens=6, + temperature=0.3, + aws_region_name="us-east-1", + aws_web_identity_token=aws_web_identity_token, + aws_role_name=aws_role_name, + aws_session_name="my-test-session", + ) + print(response_3) + assert len(response_3.choices) > 0 + assert len(response_3.choices[0].message.content) > 0 + except RateLimitError: pass except Exception as e: @@ -255,7 +287,7 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth(): aws_web_identity_token = "oidc/circleci_v2/" aws_region_name = os.environ["AWS_REGION_NAME"] # aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"] - # TODO: This is using David's IAM role, we should use Litellm's IAM role eventually + # TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci" try: