Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(util.py/azure.py): Add OIDC support when running LiteLLM on Azure + Azure Upstream caching #3861

Merged
7 commits merged on Jun 12, 2024
38 changes: 31 additions & 7 deletions litellm/llms/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from openai import AzureOpenAI, AsyncAzureOpenAI
import uuid
import os
from litellm.caching import DualCache

azure_ad_cache = DualCache()


class AzureOpenAIError(Exception):
Expand Down Expand Up @@ -136,9 +139,10 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):

def get_azure_ad_token_from_oidc(azure_ad_token: str):
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
azure_tenant = os.getenv("AZURE_TENANT_ID", None)
azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
azure_authority_host = os.getenv("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com")

if azure_client_id is None or azure_tenant is None:
if azure_client_id is None or azure_tenant_id is None:
raise AzureOpenAIError(
status_code=422,
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
Expand All @@ -152,8 +156,19 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
message="OIDC token could not be retrieved from secret manager.",
)

azure_ad_token_cache_key = json.dumps({
"azure_client_id": azure_client_id,
"azure_tenant_id": azure_tenant_id,
"azure_authority_host": azure_authority_host,
"oidc_token": oidc_token,
})

azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
if azure_ad_token_access_token is not None:
return azure_ad_token_access_token

req_token = httpx.post(
f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
data={
"client_id": azure_client_id,
"grant_type": "client_credentials",
Expand All @@ -169,14 +184,23 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
message=req_token.text,
)

possible_azure_ad_token = req_token.json().get("access_token", None)
azure_ad_token_json = req_token.json()
azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)

if possible_azure_ad_token is None:
if azure_ad_token_access_token is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token not returned"
status_code=422, message="Azure AD Token access_token not returned"
)

return possible_azure_ad_token
if azure_ad_token_expires_in is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token expires_in not returned"
)

azure_ad_cache.set_cache(key=azure_ad_token_cache_key, value=azure_ad_token_access_token, ttl=azure_ad_token_expires_in)

return azure_ad_token_access_token


class AzureChatCompletion(BaseLLM):
Expand Down
8 changes: 8 additions & 0 deletions litellm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10085,6 +10085,14 @@ def get_secret(
return oidc_token
else:
raise ValueError("Github OIDC provider failed")
elif oidc_provider == "azure":
# https://azure.github.io/azure-workload-identity/docs/quick-start.html
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
if azure_federated_token_file is None:
raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
with open(azure_federated_token_file, "r") as f:
oidc_token = f.read()
return oidc_token
else:
raise ValueError("Unsupported OIDC provider")

Expand Down