Improved wildcard route handling on /models and /model_group/info (#8473)

* fix(model_checks.py): filter the known models returned for a wildcard route by the given model prefix

ensures a wildcard route such as `vertex_ai/gemini-*` returns only the known `vertex_ai/gemini-` models (see the usage sketch below)

* test(test_proxy_utils.py): add unit testing for new 'get_known_models_from_wildcard' helper

* test(test_models.py): add e2e testing for `/model_group/info` endpoint

* feat(prometheus.py): support tracking total requests by user_email on prometheus

adds initial support for tracking total requests by user_email

* test(test_prometheus.py): add testing to ensure user email is always tracked

* test: update testing for new prometheus metric

* test(test_prometheus_unit_tests.py): add user email to total proxy metric

* test: update tests

* test: fix spend tests

* test: fix test

* fix(pagerduty.py): fix linting error
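As a rough usage sketch of the wildcard filtering described above (the example model names in the comment are assumptions; the actual result depends on the models litellm knows about locally):

```python
# Illustrative only: resolve a wildcard route to the known models matching its prefix.
from litellm.proxy.auth.model_checks import get_known_models_from_wildcard

# "vertex_ai/gemini-*" should now expand to just the vertex_ai gemini- models,
# rather than every known vertex_ai model.
matching = get_known_models_from_wildcard(wildcard_model="vertex_ai/gemini-*")
print(matching)  # e.g. ['vertex_ai/gemini-1.5-flash', 'vertex_ai/gemini-1.5-pro', ...]
```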
krrishdholakia committed Feb 13, 2025
1 parent 1852c44 commit b945a11
Showing 15 changed files with 191 additions and 39 deletions.
2 changes: 2 additions & 0 deletions litellm/integrations/pagerduty/pagerduty.py
@@ -118,6 +118,7 @@ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_ti
user_api_key_user_id=_meta.get("user_api_key_user_id"),
user_api_key_team_alias=_meta.get("user_api_key_team_alias"),
user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"),
user_api_key_user_email=_meta.get("user_api_key_user_email"),
)
)

@@ -195,6 +196,7 @@ async def hanging_response_handler(
user_api_key_user_id=user_api_key_dict.user_id,
user_api_key_team_alias=user_api_key_dict.team_alias,
user_api_key_end_user_id=user_api_key_dict.end_user_id,
user_api_key_user_email=user_api_key_dict.user_email,
)
)

3 changes: 3 additions & 0 deletions litellm/integrations/prometheus.py
@@ -423,6 +423,7 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
team=user_api_team,
team_alias=user_api_team_alias,
user=user_id,
user_email=standard_logging_payload["metadata"]["user_api_key_user_email"],
status_code="200",
model=model,
litellm_model_name=model,
@@ -806,6 +807,7 @@ async def async_post_call_failure_hook(
enum_values = UserAPIKeyLabelValues(
end_user=user_api_key_dict.end_user_id,
user=user_api_key_dict.user_id,
user_email=user_api_key_dict.user_email,
hashed_api_key=user_api_key_dict.api_key,
api_key_alias=user_api_key_dict.key_alias,
team=user_api_key_dict.team_id,
@@ -853,6 +855,7 @@ async def async_post_call_success_hook(
team=user_api_key_dict.team_id,
team_alias=user_api_key_dict.team_alias,
user=user_api_key_dict.user_id,
user_email=user_api_key_dict.user_email,
status_code="200",
)
_labels = prometheus_label_factory(
2 changes: 2 additions & 0 deletions litellm/litellm_core_utils/litellm_logging.py
@@ -2894,6 +2894,7 @@ def get_standard_logging_metadata(
user_api_key_org_id=None,
user_api_key_user_id=None,
user_api_key_team_alias=None,
user_api_key_user_email=None,
spend_logs_metadata=None,
requester_ip_address=None,
requester_metadata=None,
@@ -3328,6 +3329,7 @@ def get_standard_logging_metadata(
user_api_key_team_id=None,
user_api_key_org_id=None,
user_api_key_user_id=None,
user_api_key_user_email=None,
user_api_key_team_alias=None,
spend_logs_metadata=None,
requester_ip_address=None,
37 changes: 14 additions & 23 deletions litellm/proxy/_new_secret_config.yaml
@@ -5,6 +5,11 @@ model_list:
- model_name: gpt-4
litellm_params:
model: gpt-3.5-turbo
+  - model_name: fake-openai-endpoint
+    litellm_params:
+      model: openai/fake
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_name: azure-gpt-35-turbo
litellm_params:
model: azure/chatgpt-v-2
@@ -33,28 +38,14 @@ model_list:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  - model_name: vertex_ai/gemini-*
+    litellm_params:
+      model: vertex_ai/gemini-*
+  - model_name: fake-azure-endpoint
+    litellm_params:
+      model: openai/429
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app

litellm_settings:
-  cache: true
-
-general_settings:
-  enable_jwt_auth: True
-  forward_openai_org_id: True
-  litellm_jwtauth:
-    user_id_jwt_field: "sub"
-    team_ids_jwt_field: "groups"
-    user_id_upsert: true # add user_id to the db if they don't exist
-  enforce_team_based_model_access: true # don't allow users to access models unless the team has access
-
-router_settings:
-  redis_host: os.environ/REDIS_HOST
-  redis_password: os.environ/REDIS_PASSWORD
-  redis_port: os.environ/REDIS_PORT
-
-guardrails:
-  - guardrail_name: "aporia-pre-guard"
-    litellm_params:
-      guardrail: aporia # supported values: "aporia", "lakera"
-      mode: "during_call"
-      api_key: os.environ/APORIO_API_KEY
-      api_base: os.environ/APORIO_API_BASE
+  callbacks: ["prometheus"]
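With a wildcard entry such as the `vertex_ai/gemini-*` block added above, listing models through the proxy should return the expanded, prefix-filtered set. A minimal sketch, assuming a local proxy on port 4000 and the `sk-1234` master key used by the tests in this commit:

```python
# Sketch: list the proxy's models to see a wildcard route expanded to known models.
# The base URL and API key below mirror the test setup and are not part of this commit.
import openai

client = openai.OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

for model in client.models.list():
    print(model.id)  # expect entries like vertex_ai/gemini-1.5-pro alongside the named models
```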
1 change: 1 addition & 0 deletions litellm/proxy/_types.py
@@ -1431,6 +1431,7 @@ class UserAPIKeyAuth(
tpm_limit_per_model: Optional[Dict[str, int]] = None
user_tpm_limit: Optional[int] = None
user_rpm_limit: Optional[int] = None
user_email: Optional[str] = None

model_config = ConfigDict(arbitrary_types_allowed=True)

38 changes: 26 additions & 12 deletions litellm/proxy/auth/model_checks.py
@@ -17,16 +17,8 @@ def _check_wildcard_routing(model: str) -> bool:
- openai/*
- *
"""
-    if model == "*":
+    if "*" in model:
return True

-    if "/" in model:
-        llm_provider, potential_wildcard = model.split("/", 1)
-        if (
-            llm_provider in litellm.provider_list and potential_wildcard == "*"
-        ):  # e.g. anthropic/*
-            return True

return False


@@ -156,6 +148,28 @@ def get_complete_model_list(
return list(unique_models) + all_wildcard_models


def get_known_models_from_wildcard(wildcard_model: str) -> List[str]:
try:
provider, model = wildcard_model.split("/", 1)
except ValueError: # safely fail
return []
# get all known provider models
wildcard_models = get_provider_models(provider=provider)
if wildcard_models is None:
return []
if model == "*":
return wildcard_models or []
else:
model_prefix = model.replace("*", "")
filtered_wildcard_models = [
wc_model
for wc_model in wildcard_models
if wc_model.split("/")[1].startswith(model_prefix)
]

return filtered_wildcard_models


def _get_wildcard_models(
unique_models: Set[str], return_wildcard_routes: Optional[bool] = False
) -> List[str]:
@@ -165,13 +179,13 @@ def _get_wildcard_models(
if _check_wildcard_routing(model=model):

if (
-                return_wildcard_routes is True
+                return_wildcard_routes
): # will add the wildcard route to the list eg: anthropic/*.
all_wildcard_models.append(model)

-            provider = model.split("/")[0]
# get all known provider models
-            wildcard_models = get_provider_models(provider=provider)
+            wildcard_models = get_known_models_from_wildcard(wildcard_model=model)

if wildcard_models is not None:
models_to_remove.add(model)
all_wildcard_models.extend(wildcard_models)
1 change: 1 addition & 0 deletions litellm/proxy/auth/user_api_key_auth.py
@@ -1196,6 +1196,7 @@ async def _return_user_api_key_auth_obj(
user_api_key_kwargs.update(
user_tpm_limit=user_obj.tpm_limit,
user_rpm_limit=user_obj.rpm_limit,
user_email=user_obj.user_email,
)
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
user_api_key_kwargs.update(
1 change: 1 addition & 0 deletions litellm/proxy/litellm_pre_call_utils.py
@@ -312,6 +312,7 @@ def get_sanitized_user_information_from_key(
user_api_key_org_id=user_api_key_dict.org_id,
user_api_key_team_alias=user_api_key_dict.team_alias,
user_api_key_end_user_id=user_api_key_dict.end_user_id,
user_api_key_user_email=user_api_key_dict.user_email,
)
return user_api_key_logged_metadata

6 changes: 6 additions & 0 deletions litellm/types/integrations/prometheus.py
@@ -54,6 +54,7 @@
class UserAPIKeyLabelNames(Enum):
END_USER = "end_user"
USER = "user"
USER_EMAIL = "user_email"
API_KEY_HASH = "hashed_api_key"
API_KEY_ALIAS = "api_key_alias"
TEAM = "team"
@@ -123,6 +124,7 @@ class PrometheusMetricLabels:
UserAPIKeyLabelNames.TEAM_ALIAS.value,
UserAPIKeyLabelNames.USER.value,
UserAPIKeyLabelNames.STATUS_CODE.value,
UserAPIKeyLabelNames.USER_EMAIL.value,
]

litellm_proxy_failed_requests_metric = [
@@ -156,6 +158,7 @@ class PrometheusMetricLabels:
UserAPIKeyLabelNames.TEAM.value,
UserAPIKeyLabelNames.TEAM_ALIAS.value,
UserAPIKeyLabelNames.USER.value,
UserAPIKeyLabelNames.USER_EMAIL.value,
]

litellm_input_tokens_metric = [
@@ -240,6 +243,9 @@ class UserAPIKeyLabelValues(BaseModel):
user: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.USER.value)
] = None
user_email: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.USER_EMAIL.value)
] = None
hashed_api_key: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.API_KEY_HASH.value)
] = None
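For reference, a prometheus_client sketch of a counter carrying the same label set, including the new `user_email` label; this mirrors the label names above but is not the integration's actual metric definition:

```python
# Illustrative counter with the proxy-request label set, including user_email.
from prometheus_client import Counter

proxy_requests = Counter(
    "litellm_proxy_total_requests_metric",
    "Total requests made to the proxy (illustrative definition)",
    labelnames=[
        "end_user", "hashed_api_key", "api_key_alias", "requested_model",
        "team", "team_alias", "user", "user_email", "status_code",
    ],
)

# Record one successful request attributed to a user's email.
proxy_requests.labels(
    end_user="None", hashed_api_key="abc123", api_key_alias="None",
    requested_model="fake-openai-endpoint", team="None", team_alias="None",
    user="default_user_id", user_email="test@example.com", status_code="200",
).inc()
```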
1 change: 1 addition & 0 deletions litellm/types/utils.py
@@ -1504,6 +1504,7 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
user_api_key_org_id: Optional[str]
user_api_key_team_id: Optional[str]
user_api_key_user_id: Optional[str]
user_api_key_user_email: Optional[str]
user_api_key_team_alias: Optional[str]
user_api_key_end_user_id: Optional[str]

1 change: 1 addition & 0 deletions tests/logging_callback_tests/test_otel_logging.py
@@ -272,6 +272,7 @@ def validate_redacted_message_span_attributes(span):
"metadata.user_api_key_user_id",
"metadata.user_api_key_org_id",
"metadata.user_api_key_end_user_id",
"metadata.user_api_key_user_email",
"metadata.applied_guardrails",
]

4 changes: 4 additions & 0 deletions tests/logging_callback_tests/test_prometheus_unit_tests.py
@@ -73,6 +73,7 @@ def create_standard_logging_payload() -> StandardLoggingPayload:
user_api_key_alias="test_alias",
user_api_key_team_id="test_team",
user_api_key_user_id="test_user",
user_api_key_user_email="test@example.com",
user_api_key_team_alias="test_team_alias",
user_api_key_org_id=None,
spend_logs_metadata=None,
@@ -475,6 +476,7 @@ def test_increment_top_level_request_and_spend_metrics(prometheus_logger):
team="test_team",
team_alias="test_team_alias",
model="gpt-3.5-turbo",
user_email=None,
)
prometheus_logger.litellm_requests_metric.labels().inc.assert_called_once()

@@ -631,6 +633,7 @@ async def test_async_post_call_failure_hook(prometheus_logger):
team_alias="test_team_alias",
user="test_user",
status_code="429",
user_email=None,
)
prometheus_logger.litellm_proxy_total_requests_metric.labels().inc.assert_called_once()

@@ -674,6 +677,7 @@ async def test_async_post_call_success_hook(prometheus_logger):
team_alias="test_team_alias",
user="test_user",
status_code="200",
user_email=None,
)
prometheus_logger.litellm_proxy_total_requests_metric.labels().inc.assert_called_once()

59 changes: 56 additions & 3 deletions tests/otel_tests/test_prometheus.py
@@ -111,12 +111,12 @@ async def test_proxy_failure_metrics():

assert (
expected_metric in metrics
-    ), "Expected failure metric not found in /metrics"
-    expected_llm_deployment_failure = 'litellm_deployment_failure_responses_total{api_base="https://exampleopenaiendpoint-production.up.railway.app",api_provider="openai",exception_class="RateLimitError",exception_status="429",litellm_model_name="429",model_id="7499d31f98cd518cf54486d5a00deda6894239ce16d13543398dc8abf870b15f",requested_model="fake-azure-endpoint"} 1.0'
+    ), "Expected failure metric not found in /metrics."
+    expected_llm_deployment_failure = 'litellm_deployment_failure_responses_total{api_key_alias="None",end_user="None",hashed_api_key="88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",requested_model="fake-azure-endpoint",status_code="429",team="None",team_alias="None",user="default_user_id",user_email="None"} 1.0'
assert expected_llm_deployment_failure

assert (
-        'litellm_proxy_total_requests_metric_total{api_key_alias="None",end_user="None",hashed_api_key="88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",requested_model="fake-azure-endpoint",status_code="429",team="None",team_alias="None",user="default_user_id"} 1.0'
+        'litellm_proxy_total_requests_metric_total{api_key_alias="None",end_user="None",hashed_api_key="88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",requested_model="fake-azure-endpoint",status_code="429",team="None",team_alias="None",user="default_user_id",user_email="None"} 1.0'
in metrics
)

@@ -258,6 +258,24 @@ async def create_test_team(
return team_info["team_id"]


async def create_test_user(
session: aiohttp.ClientSession, user_data: Dict[str, Any]
) -> str:
"""Create a new user and return the user_id"""
url = "http://0.0.0.0:4000/user/new"
headers = {
"Authorization": "Bearer sk-1234",
"Content-Type": "application/json",
}

async with session.post(url, headers=headers, json=user_data) as response:
assert (
response.status == 200
), f"Failed to create user. Status: {response.status}"
user_info = await response.json()
return user_info


async def get_prometheus_metrics(session: aiohttp.ClientSession) -> str:
"""Fetch current prometheus metrics"""
async with session.get("http://0.0.0.0:4000/metrics") as response:
@@ -526,3 +544,38 @@ async def test_key_budget_metrics():
assert (
abs(key_info_remaining_budget - first_budget["remaining"]) <= 0.00000
), f"Spend mismatch: Prometheus={key_info_remaining_budget}, Key Info={first_budget['remaining']}"


@pytest.mark.asyncio
async def test_user_email_metrics():
"""
Test user email tracking metrics:
1. Create a user with user_email
2. Make chat completion requests using OpenAI SDK with the user's email
3. Verify user email is being tracked correctly in `litellm_user_email_metric`
"""
async with aiohttp.ClientSession() as session:
# Create a user with user_email
user_data = {
"user_email": "test@example.com",
}
user_info = await create_test_user(session, user_data)
key = user_info["key"]

# Initialize OpenAI client with the user's email
client = AsyncOpenAI(base_url="http://0.0.0.0:4000", api_key=key)

        # Make a chat completion request with the new user's key
await client.chat.completions.create(
model="fake-openai-endpoint",
messages=[{"role": "user", "content": f"Hello {uuid.uuid4()}"}],
)

await asyncio.sleep(11) # Wait for metrics to update

# Get metrics after request
metrics_after_first = await get_prometheus_metrics(session)
print("metrics_after_first request", metrics_after_first)
assert (
"test@example.com" in metrics_after_first
), "user_email should be tracked correctly"
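A small follow-up sketch for spot-checking the new label against a running proxy (same local endpoint these tests assume):

```python
# Fetch /metrics and print any proxy request series that carry a user_email label.
import asyncio

import aiohttp


async def print_user_email_series() -> None:
    async with aiohttp.ClientSession() as session:
        async with session.get("http://0.0.0.0:4000/metrics") as response:
            metrics = await response.text()
    for line in metrics.splitlines():
        if "litellm_proxy_total_requests_metric" in line and "user_email=" in line:
            print(line)


asyncio.run(print_user_email_series())
```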
27 changes: 26 additions & 1 deletion tests/proxy_unit_tests/test_proxy_utils.py
@@ -1618,7 +1618,31 @@ def test_provider_specific_header():
},
}


@pytest.mark.parametrize(
"wildcard_model, expected_models",
[
(
"anthropic/*",
["anthropic/claude-3-5-haiku-20241022", "anthropic/claude-3-opus-20240229"],
),
(
"vertex_ai/gemini-*",
["vertex_ai/gemini-1.5-flash", "vertex_ai/gemini-1.5-pro"],
),
],
)
def test_get_known_models_from_wildcard(wildcard_model, expected_models):
from litellm.proxy.auth.model_checks import get_known_models_from_wildcard

wildcard_models = get_known_models_from_wildcard(wildcard_model=wildcard_model)
# Check if all expected models are in the returned list
print(f"wildcard_models: {wildcard_models}\n")
for model in expected_models:
if model not in wildcard_models:
print(f"Missing expected model: {model}")

assert all(model in wildcard_models for model in expected_models)

@pytest.mark.parametrize(
"data, user_api_key_dict, expected_model",
[
@@ -1667,3 +1691,4 @@ def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_mod

# Check if model was updated correctly
assert test_data.get("model") == expected_model
