Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IAM cred caching for OIDC flow #3712

Merged
136 changes: 84 additions & 52 deletions litellm/llms/bedrock_httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@
ChatCompletionToolCallFunctionChunk,
ChatCompletionDeltaChunk,
)
from litellm.caching import DualCache

iam_cache = DualCache()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

switch to in memory cache

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InMemoryCache -

class InMemoryCache(BaseCache):


class AmazonCohereChatConfig:
"""
Expand Down Expand Up @@ -325,38 +327,53 @@ def get_credentials(
) = params_to_check

### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
oidc_token = get_secret(aws_web_identity_token)
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a comment explaining when you'd enter this flow

iam_creds_cache_key = json.dumps({
"aws_web_identity_token": aws_web_identity_token,
"aws_role_name": aws_role_name,
"aws_session_name": aws_session_name,
"aws_region_name": aws_region_name,
})

iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
if iam_creds_dict is None:
oidc_token = get_secret(aws_web_identity_token)

if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)

if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com"
)

sts_client = boto3.client("sts")
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)

# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
iam_creds_dict = {
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}

session = boto3.Session(
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=aws_region_name,
)
iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have ttl be an enum at the top of the file, so it's easier to know


return session.get_credentials()
session = boto3.Session(**iam_creds_dict)

iam_creds = session.get_credentials()

return iam_creds
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
Expand Down Expand Up @@ -1416,38 +1433,53 @@ def get_credentials(
) = params_to_check

### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
oidc_token = get_secret(aws_web_identity_token)
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
iam_creds_cache_key = json.dumps({
"aws_web_identity_token": aws_web_identity_token,
"aws_role_name": aws_role_name,
"aws_session_name": aws_session_name,
"aws_region_name": aws_region_name,
})

iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
if iam_creds_dict is None:
oidc_token = get_secret(aws_web_identity_token)

if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)

if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com"
)

sts_client = boto3.client("sts")
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)

# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
iam_creds_dict = {
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}

session = boto3.Session(
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=aws_region_name,
)
iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60)

return session.get_credentials()
session = boto3.Session(**iam_creds_dict)

iam_creds = session.get_credentials()

return iam_creds
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
Expand Down
42 changes: 37 additions & 5 deletions litellm/tests/test_bedrock_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,13 @@ def test_completion_bedrock_claude_sts_oidc_auth():
aws_web_identity_token = "oidc/circleci_v2/"
aws_region_name = os.environ["AWS_REGION_NAME"]
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
# TODO: This is using David's IAM role, we should use Litellm's IAM role eventually
# TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"

try:
litellm.set_verbose = True

response = completion(
response_1 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=10,
Expand All @@ -236,8 +236,40 @@ def test_completion_bedrock_claude_sts_oidc_auth():
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
# Add any assertions here to check the response
print(response)
print(response_1)
assert len(response_1.choices) > 0
assert len(response_1.choices[0].message.content) > 0

# This second call is to verify that the cache isn't breaking anything
response_2 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=5,
temperature=0.2,
aws_region_name=aws_region_name,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
print(response_2)
assert len(response_2.choices) > 0
assert len(response_2.choices[0].message.content) > 0

# This third call is to verify that the cache isn't used for a different region
response_3 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=6,
temperature=0.3,
aws_region_name="us-east-1",
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
print(response_3)
assert len(response_3.choices) > 0
assert len(response_3.choices[0].message.content) > 0

except RateLimitError:
pass
except Exception as e:
Expand All @@ -255,7 +287,7 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth():
aws_web_identity_token = "oidc/circleci_v2/"
aws_region_name = os.environ["AWS_REGION_NAME"]
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
# TODO: This is using David's IAM role, we should use Litellm's IAM role eventually
# TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"

try:
Expand Down