Improved wildcard route handling on /models and /model_group/info #8473

Merged · 12 commits · Feb 12, 2025
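This PR extends wildcard route expansion so that prefix wildcards such as `vertex_ai/gemini-*` are recognized in addition to `provider/*`, and `/models` and `/model_group/info` list the concrete provider models a wildcard covers. As a rough sketch of the user-visible effect (assuming a locally running proxy at http://0.0.0.0:4000 with such a wildcard route configured and `sk-1234` as the admin key, mirroring the test setup further below):

```python
import requests

# /models returns an OpenAI-style model list; with this change, a configured
# wildcard route like "vertex_ai/gemini-*" is expanded into the known models it matches.
resp = requests.get(
    "http://0.0.0.0:4000/models",
    headers={"Authorization": "Bearer sk-1234"},
)
model_ids = [m["id"] for m in resp.json()["data"]]
print(model_ids)  # e.g. [..., "vertex_ai/gemini-1.5-flash", "vertex_ai/gemini-1.5-pro", ...]
```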
2 changes: 2 additions & 0 deletions litellm/integrations/pagerduty/pagerduty.py
@@ -118,6 +118,7 @@ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_ti
user_api_key_user_id=_meta.get("user_api_key_user_id"),
user_api_key_team_alias=_meta.get("user_api_key_team_alias"),
user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"),
user_api_key_user_email=_meta.get("user_api_key_user_email"),
)
)

@@ -195,6 +196,7 @@ async def hanging_response_handler(
user_api_key_user_id=user_api_key_dict.user_id,
user_api_key_team_alias=user_api_key_dict.team_alias,
user_api_key_end_user_id=user_api_key_dict.end_user_id,
user_api_key_user_email=user_api_key_dict.user_email,
)
)

3 changes: 3 additions & 0 deletions litellm/integrations/prometheus.py
@@ -423,6 +423,7 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
team=user_api_team,
team_alias=user_api_team_alias,
user=user_id,
user_email=standard_logging_payload["metadata"]["user_api_key_user_email"],
status_code="200",
model=model,
litellm_model_name=model,
@@ -806,6 +807,7 @@ async def async_post_call_failure_hook(
enum_values = UserAPIKeyLabelValues(
end_user=user_api_key_dict.end_user_id,
user=user_api_key_dict.user_id,
user_email=user_api_key_dict.user_email,
hashed_api_key=user_api_key_dict.api_key,
api_key_alias=user_api_key_dict.key_alias,
team=user_api_key_dict.team_id,
@@ -853,6 +855,7 @@ async def async_post_call_success_hook(
team=user_api_key_dict.team_id,
team_alias=user_api_key_dict.team_alias,
user=user_api_key_dict.user_id,
user_email=user_api_key_dict.user_email,
status_code="200",
)
_labels = prometheus_label_factory(
2 changes: 2 additions & 0 deletions litellm/litellm_core_utils/litellm_logging.py
@@ -2894,6 +2894,7 @@ def get_standard_logging_metadata(
user_api_key_org_id=None,
user_api_key_user_id=None,
user_api_key_team_alias=None,
user_api_key_user_email=None,
spend_logs_metadata=None,
requester_ip_address=None,
requester_metadata=None,
@@ -3328,6 +3329,7 @@ def get_standard_logging_metadata(
user_api_key_team_id=None,
user_api_key_org_id=None,
user_api_key_user_id=None,
user_api_key_user_email=None,
user_api_key_team_alias=None,
spend_logs_metadata=None,
requester_ip_address=None,
37 changes: 14 additions & 23 deletions litellm/proxy/_new_secret_config.yaml
@@ -5,6 +5,11 @@ model_list:
- model_name: gpt-4
litellm_params:
model: gpt-3.5-turbo
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_name: azure-gpt-35-turbo
litellm_params:
model: azure/chatgpt-v-2
@@ -33,28 +38,14 @@ model_list:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_name: vertex_ai/gemini-*
litellm_params:
model: vertex_ai/gemini-*
- model_name: fake-azure-endpoint
litellm_params:
model: openai/429
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app

litellm_settings:
cache: true

general_settings:
enable_jwt_auth: True
forward_openai_org_id: True
litellm_jwtauth:
user_id_jwt_field: "sub"
team_ids_jwt_field: "groups"
user_id_upsert: true # add user_id to the db if they don't exist
enforce_team_based_model_access: true # don't allow users to access models unless the team has access

router_settings:
redis_host: os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT

guardrails:
- guardrail_name: "aporia-pre-guard"
litellm_params:
guardrail: aporia # supported values: "aporia", "lakera"
mode: "during_call"
api_key: os.environ/APORIO_API_KEY
api_base: os.environ/APORIO_API_BASE
callbacks: ["prometheus"]
1 change: 1 addition & 0 deletions litellm/proxy/_types.py
@@ -1431,6 +1431,7 @@ class UserAPIKeyAuth(
tpm_limit_per_model: Optional[Dict[str, int]] = None
user_tpm_limit: Optional[int] = None
user_rpm_limit: Optional[int] = None
user_email: Optional[str] = None

model_config = ConfigDict(arbitrary_types_allowed=True)

38 changes: 26 additions & 12 deletions litellm/proxy/auth/model_checks.py
@@ -17,16 +17,8 @@ def _check_wildcard_routing(model: str) -> bool:
- openai/*
- *
"""
if model == "*":
if "*" in model:
return True

if "/" in model:
llm_provider, potential_wildcard = model.split("/", 1)
if (
llm_provider in litellm.provider_list and potential_wildcard == "*"
): # e.g. anthropic/*
return True

return False
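For clarity, the simplified predicate now flags any route containing a `*` (a small illustrative sketch; the old code only matched `*` and `provider/*`):

```python
# After this change, prefix wildcards also count as wildcard routes.
assert _check_wildcard_routing("*") is True                   # matched before and after
assert _check_wildcard_routing("anthropic/*") is True         # matched before and after
assert _check_wildcard_routing("vertex_ai/gemini-*") is True  # previously returned False
assert _check_wildcard_routing("gpt-4") is False              # no wildcard
```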


@@ -156,6 +148,28 @@ def get_complete_model_list(
return list(unique_models) + all_wildcard_models


def get_known_models_from_wildcard(wildcard_model: str) -> List[str]:
try:
provider, model = wildcard_model.split("/", 1)
except ValueError: # safely fail
return []
# get all known provider models
wildcard_models = get_provider_models(provider=provider)
if wildcard_models is None:
return []
if model == "*":
return wildcard_models or []
else:
model_prefix = model.replace("*", "")
filtered_wildcard_models = [
wc_model
for wc_model in wildcard_models
if wc_model.split("/")[1].startswith(model_prefix)
]

return filtered_wildcard_models
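A usage sketch of the new helper (the concrete model names depend on litellm's known-provider lists via get_provider_models, so the outputs shown are illustrative; they mirror the unit test added below):

```python
# Full-provider wildcard: return every known model for that provider.
get_known_models_from_wildcard("anthropic/*")
# -> ["anthropic/claude-3-5-haiku-20241022", "anthropic/claude-3-opus-20240229", ...]

# Prefix wildcard: keep only models whose name after the provider prefix starts with "gemini-".
get_known_models_from_wildcard("vertex_ai/gemini-*")
# -> ["vertex_ai/gemini-1.5-flash", "vertex_ai/gemini-1.5-pro", ...]

# No "/" to split on: the ValueError from unpacking is caught and [] is returned.
get_known_models_from_wildcard("gpt-4*")
# -> []
```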


def _get_wildcard_models(
unique_models: Set[str], return_wildcard_routes: Optional[bool] = False
) -> List[str]:
@@ -165,13 +179,13 @@ def _get_wildcard_models(
if _check_wildcard_routing(model=model):

if (
return_wildcard_routes is True
return_wildcard_routes
): # will add the wildcard route to the list eg: anthropic/*.
all_wildcard_models.append(model)

provider = model.split("/")[0]
# get all known provider models
wildcard_models = get_provider_models(provider=provider)
wildcard_models = get_known_models_from_wildcard(wildcard_model=model)

if wildcard_models is not None:
Reviewer: just a note: wildcard_models will never be None now

Contributor Author: good point
models_to_remove.add(model)
all_wildcard_models.extend(wildcard_models)
1 change: 1 addition & 0 deletions litellm/proxy/auth/user_api_key_auth.py
@@ -1196,6 +1196,7 @@ async def _return_user_api_key_auth_obj(
user_api_key_kwargs.update(
user_tpm_limit=user_obj.tpm_limit,
user_rpm_limit=user_obj.rpm_limit,
user_email=user_obj.user_email,
)
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
user_api_key_kwargs.update(
1 change: 1 addition & 0 deletions litellm/proxy/litellm_pre_call_utils.py
@@ -312,6 +312,7 @@ def get_sanitized_user_information_from_key(
user_api_key_org_id=user_api_key_dict.org_id,
user_api_key_team_alias=user_api_key_dict.team_alias,
user_api_key_end_user_id=user_api_key_dict.end_user_id,
user_api_key_user_email=user_api_key_dict.user_email,
)
return user_api_key_logged_metadata

6 changes: 6 additions & 0 deletions litellm/types/integrations/prometheus.py
@@ -54,6 +54,7 @@
class UserAPIKeyLabelNames(Enum):
END_USER = "end_user"
USER = "user"
USER_EMAIL = "user_email"
API_KEY_HASH = "hashed_api_key"
API_KEY_ALIAS = "api_key_alias"
TEAM = "team"
@@ -123,6 +124,7 @@ class PrometheusMetricLabels:
UserAPIKeyLabelNames.TEAM_ALIAS.value,
UserAPIKeyLabelNames.USER.value,
UserAPIKeyLabelNames.STATUS_CODE.value,
UserAPIKeyLabelNames.USER_EMAIL.value,
]

litellm_proxy_failed_requests_metric = [
@@ -156,6 +158,7 @@ class PrometheusMetricLabels:
UserAPIKeyLabelNames.TEAM.value,
UserAPIKeyLabelNames.TEAM_ALIAS.value,
UserAPIKeyLabelNames.USER.value,
UserAPIKeyLabelNames.USER_EMAIL.value,
]

litellm_input_tokens_metric = [
@@ -240,6 +243,9 @@ class UserAPIKeyLabelValues(BaseModel):
user: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.USER.value)
] = None
user_email: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.USER_EMAIL.value)
] = None
hashed_api_key: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.API_KEY_HASH.value)
] = None
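To show how the new label flows through, a minimal sketch (assuming prometheus_label_factory accepts `supported_enum_labels` and `enum_values` keyword arguments; its call sites are truncated in this diff, so treat that signature as an assumption, and the values below are illustrative):

```python
enum_values = UserAPIKeyLabelValues(
    user="test_user",
    user_email="test@example.com",  # new field added in this PR
    hashed_api_key="88dc28d0f030c55ed4ab77ed8faf0981...",
    api_key_alias="test_alias",
    team="test_team",
    team_alias="test_team_alias",
    status_code="200",
)
labels = prometheus_label_factory(
    supported_enum_labels=[
        UserAPIKeyLabelNames.USER.value,
        UserAPIKeyLabelNames.USER_EMAIL.value,
        UserAPIKeyLabelNames.TEAM.value,
        UserAPIKeyLabelNames.STATUS_CODE.value,
    ],
    enum_values=enum_values,
)
# labels should now include user_email alongside the existing user/team labels,
# e.g. {"user": "test_user", "user_email": "test@example.com", ...}
```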
1 change: 1 addition & 0 deletions litellm/types/utils.py
@@ -1504,6 +1504,7 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
user_api_key_org_id: Optional[str]
user_api_key_team_id: Optional[str]
user_api_key_user_id: Optional[str]
user_api_key_user_email: Optional[str]
user_api_key_team_alias: Optional[str]
user_api_key_end_user_id: Optional[str]

1 change: 1 addition & 0 deletions tests/logging_callback_tests/test_otel_logging.py
@@ -272,6 +272,7 @@ def validate_redacted_message_span_attributes(span):
"metadata.user_api_key_user_id",
"metadata.user_api_key_org_id",
"metadata.user_api_key_end_user_id",
"metadata.user_api_key_user_email",
"metadata.applied_guardrails",
]

4 changes: 4 additions & 0 deletions tests/logging_callback_tests/test_prometheus_unit_tests.py
@@ -73,6 +73,7 @@ def create_standard_logging_payload() -> StandardLoggingPayload:
user_api_key_alias="test_alias",
user_api_key_team_id="test_team",
user_api_key_user_id="test_user",
user_api_key_user_email="test@example.com",
user_api_key_team_alias="test_team_alias",
user_api_key_org_id=None,
spend_logs_metadata=None,
@@ -475,6 +476,7 @@ def test_increment_top_level_request_and_spend_metrics(prometheus_logger):
team="test_team",
team_alias="test_team_alias",
model="gpt-3.5-turbo",
user_email=None,
)
prometheus_logger.litellm_requests_metric.labels().inc.assert_called_once()

@@ -631,6 +633,7 @@ async def test_async_post_call_failure_hook(prometheus_logger):
team_alias="test_team_alias",
user="test_user",
status_code="429",
user_email=None,
)
prometheus_logger.litellm_proxy_total_requests_metric.labels().inc.assert_called_once()

@@ -674,6 +677,7 @@ async def test_async_post_call_success_hook(prometheus_logger):
team_alias="test_team_alias",
user="test_user",
status_code="200",
user_email=None,
)
prometheus_logger.litellm_proxy_total_requests_metric.labels().inc.assert_called_once()

59 changes: 56 additions & 3 deletions tests/otel_tests/test_prometheus.py
@@ -111,12 +111,12 @@ async def test_proxy_failure_metrics():

assert (
expected_metric in metrics
), "Expected failure metric not found in /metrics"
expected_llm_deployment_failure = 'litellm_deployment_failure_responses_total{api_base="https://exampleopenaiendpoint-production.up.railway.app",api_provider="openai",exception_class="RateLimitError",exception_status="429",litellm_model_name="429",model_id="7499d31f98cd518cf54486d5a00deda6894239ce16d13543398dc8abf870b15f",requested_model="fake-azure-endpoint"} 1.0'
), "Expected failure metric not found in /metrics."
expected_llm_deployment_failure = 'litellm_deployment_failure_responses_total{api_key_alias="None",end_user="None",hashed_api_key="88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",requested_model="fake-azure-endpoint",status_code="429",team="None",team_alias="None",user="default_user_id",user_email="None"} 1.0'
assert expected_llm_deployment_failure

assert (
'litellm_proxy_total_requests_metric_total{api_key_alias="None",end_user="None",hashed_api_key="88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",requested_model="fake-azure-endpoint",status_code="429",team="None",team_alias="None",user="default_user_id"} 1.0'
'litellm_proxy_total_requests_metric_total{api_key_alias="None",end_user="None",hashed_api_key="88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",requested_model="fake-azure-endpoint",status_code="429",team="None",team_alias="None",user="default_user_id",user_email="None"} 1.0'
in metrics
)

@@ -258,6 +258,24 @@ async def create_test_team(
return team_info["team_id"]


async def create_test_user(
session: aiohttp.ClientSession, user_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Create a new user and return the /user/new response (includes the generated key)"""
url = "http://0.0.0.0:4000/user/new"
headers = {
"Authorization": "Bearer sk-1234",
"Content-Type": "application/json",
}

async with session.post(url, headers=headers, json=user_data) as response:
assert (
response.status == 200
), f"Failed to create user. Status: {response.status}"
user_info = await response.json()
return user_info


async def get_prometheus_metrics(session: aiohttp.ClientSession) -> str:
"""Fetch current prometheus metrics"""
async with session.get("http://0.0.0.0:4000/metrics") as response:
@@ -526,3 +544,38 @@ async def test_key_budget_metrics():
assert (
abs(key_info_remaining_budget - first_budget["remaining"]) <= 0.00000
), f"Spend mismatch: Prometheus={key_info_remaining_budget}, Key Info={first_budget['remaining']}"


@pytest.mark.asyncio
async def test_user_email_metrics():
"""
Test user email tracking metrics:
1. Create a user with user_email
2. Make chat completion requests using OpenAI SDK with the user's email
3. Verify user email is being tracked correctly in `litellm_user_email_metric`
"""
async with aiohttp.ClientSession() as session:
# Create a user with user_email
user_data = {
"user_email": "test@example.com",
}
user_info = await create_test_user(session, user_data)
key = user_info["key"]

# Initialize OpenAI client with the user's email
client = AsyncOpenAI(base_url="http://0.0.0.0:4000", api_key=key)

# Make a request with the user's key
await client.chat.completions.create(
model="fake-openai-endpoint",
messages=[{"role": "user", "content": f"Hello {uuid.uuid4()}"}],
)

await asyncio.sleep(11) # Wait for metrics to update

# Get metrics after request
metrics_after_first = await get_prometheus_metrics(session)
print("metrics_after_first request", metrics_after_first)
assert (
"test@example.com" in metrics_after_first
), "user_email should be tracked correctly"
27 changes: 26 additions & 1 deletion tests/proxy_unit_tests/test_proxy_utils.py
@@ -1618,7 +1618,31 @@ def test_provider_specific_header():
},
}


@pytest.mark.parametrize(
"wildcard_model, expected_models",
[
(
"anthropic/*",
["anthropic/claude-3-5-haiku-20241022", "anthropic/claude-3-opus-20240229"],
),
(
"vertex_ai/gemini-*",
["vertex_ai/gemini-1.5-flash", "vertex_ai/gemini-1.5-pro"],
),
],
)
def test_get_known_models_from_wildcard(wildcard_model, expected_models):
from litellm.proxy.auth.model_checks import get_known_models_from_wildcard

wildcard_models = get_known_models_from_wildcard(wildcard_model=wildcard_model)
# Check if all expected models are in the returned list
print(f"wildcard_models: {wildcard_models}\n")
for model in expected_models:
if model not in wildcard_models:
print(f"Missing expected model: {model}")

assert all(model in wildcard_models for model in expected_models)

@pytest.mark.parametrize(
"data, user_api_key_dict, expected_model",
[
@@ -1667,3 +1691,4 @@ def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_mod

# Check if model was updated correctly
assert test_data.get("model") == expected_model
