Improved wildcard route handling on /models and /model_group/info #8473

Merged · 12 commits · Feb 12, 2025
Changes from 3 commits
28 changes: 3 additions & 25 deletions litellm/proxy/_new_secret_config.yaml
@@ -33,28 +33,6 @@ model_list:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/

litellm_settings:
  cache: true

general_settings:
  enable_jwt_auth: True
  forward_openai_org_id: True
  litellm_jwtauth:
    user_id_jwt_field: "sub"
    team_ids_jwt_field: "groups"
    user_id_upsert: true # add user_id to the db if they don't exist
    enforce_team_based_model_access: true # don't allow users to access models unless the team has access

router_settings:
  redis_host: os.environ/REDIS_HOST
  redis_password: os.environ/REDIS_PASSWORD
  redis_port: os.environ/REDIS_PORT

guardrails:
  - guardrail_name: "aporia-pre-guard"
    litellm_params:
      guardrail: aporia # supported values: "aporia", "lakera"
      mode: "during_call"
      api_key: os.environ/APORIO_API_KEY
      api_base: os.environ/APORIO_API_BASE
  - model_name: vertex_ai/gemini-*
    litellm_params:
      model: vertex_ai/gemini-*
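For context, a wildcard entry like the one added above lets a single deployment match any request whose model name fits the pattern. A rough sketch of the same idea at the SDK level, assuming the router's wildcard/pattern routing and hypothetical Gemini model names:

```python
# Sketch only: one wildcard deployment covering every "vertex_ai/gemini-*" request.
# Assumes Vertex AI credentials are available in the environment.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "vertex_ai/gemini-*",  # wildcard route
            "litellm_params": {"model": "vertex_ai/gemini-*"},
        }
    ]
)

# Both calls below would resolve against the single wildcard entry, e.g.:
# await router.acompletion(model="vertex_ai/gemini-1.5-pro", messages=[{"role": "user", "content": "hi"}])
# await router.acompletion(model="vertex_ai/gemini-1.5-flash", messages=[{"role": "user", "content": "hi"}])
```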
38 changes: 26 additions & 12 deletions litellm/proxy/auth/model_checks.py
@@ -17,16 +17,8 @@ def _check_wildcard_routing(model: str) -> bool:
    - openai/*
    - *
    """
    if model == "*":
    if "*" in model:
        return True

    if "/" in model:
        llm_provider, potential_wildcard = model.split("/", 1)
        if (
            llm_provider in litellm.provider_list and potential_wildcard == "*"
        ): # e.g. anthropic/*
            return True

    return False
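In effect, the simplified check now flags any model name containing "*" as a wildcard route, which covers partial patterns as well as full provider wildcards. A minimal standalone sketch of the new behavior (not the proxy code itself):

```python
def is_wildcard_route(model: str) -> bool:
    # New rule: any "*" anywhere in the model name marks a wildcard route.
    return "*" in model

assert is_wildcard_route("*")                   # global wildcard
assert is_wildcard_route("anthropic/*")         # full provider wildcard
assert is_wildcard_route("vertex_ai/gemini-*")  # partial wildcard, newly handled
assert not is_wildcard_route("gpt-4o")          # concrete model name
```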


@@ -156,6 +148,28 @@ def get_complete_model_list(
    return list(unique_models) + all_wildcard_models


def get_known_models_from_wildcard(wildcard_model: str) -> List[str]:
    try:
        provider, model = wildcard_model.split("/", 1)
    except ValueError:  # safely fail
        return []
    # get all known provider models
    wildcard_models = get_provider_models(provider=provider)
    if wildcard_models is None:
        return []
    if model == "*":
        return wildcard_models or []
    else:
        model_prefix = model.replace("*", "")
        filtered_wildcard_models = [
            wc_model
            for wc_model in wildcard_models
            if wc_model.split("/")[1].startswith(model_prefix)
        ]

        return filtered_wildcard_models
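A brief usage sketch of the new helper, with illustrative outputs only (the real lists come from the provider model registry; note that the prefix filter assumes provider-prefixed names such as "vertex_ai/gemini-1.5-pro"):

```python
from litellm.proxy.auth.model_checks import get_known_models_from_wildcard

# Full provider wildcard: every known model for that provider.
get_known_models_from_wildcard(wildcard_model="anthropic/*")
# -> e.g. ["anthropic/claude-3-5-haiku-20241022", "anthropic/claude-3-opus-20240229", ...]

# Partial wildcard: only models whose names start with the prefix before "*".
get_known_models_from_wildcard(wildcard_model="vertex_ai/gemini-*")
# -> e.g. ["vertex_ai/gemini-1.5-flash", "vertex_ai/gemini-1.5-pro", ...]

# Input without a provider prefix fails safely with an empty list.
get_known_models_from_wildcard(wildcard_model="gemini-*")
# -> []
```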


def _get_wildcard_models(
    unique_models: Set[str], return_wildcard_routes: Optional[bool] = False
) -> List[str]:
@@ -165,13 +179,13 @@ def _get_wildcard_models(
        if _check_wildcard_routing(model=model):

            if (
                return_wildcard_routes is True
                return_wildcard_routes
            ): # will add the wildcard route to the list eg: anthropic/*.
                all_wildcard_models.append(model)

            provider = model.split("/")[0]
            # get all known provider models
            wildcard_models = get_provider_models(provider=provider)
            wildcard_models = get_known_models_from_wildcard(wildcard_model=model)

            if wildcard_models is not None:
Review comment: just a note: wildcard_models will never be None now

Contributor Author: good point

                models_to_remove.add(model)
                all_wildcard_models.extend(wildcard_models)
26 changes: 26 additions & 0 deletions tests/proxy_unit_tests/test_proxy_utils.py
@@ -1617,3 +1617,29 @@ def test_provider_specific_header():
"anthropic-beta": "prompt-caching-2024-07-31",
},
}


@pytest.mark.parametrize(
    "wildcard_model, expected_models",
    [
        (
            "anthropic/*",
            ["anthropic/claude-3-5-haiku-20241022", "anthropic/claude-3-opus-20240229"],
        ),
        (
            "vertex_ai/gemini-*",
            ["vertex_ai/gemini-1.5-flash", "vertex_ai/gemini-1.5-pro"],
        ),
    ],
)
def test_get_known_models_from_wildcard(wildcard_model, expected_models):
    from litellm.proxy.auth.model_checks import get_known_models_from_wildcard

    wildcard_models = get_known_models_from_wildcard(wildcard_model=wildcard_model)
    # Check if all expected models are in the returned list
    print(f"wildcard_models: {wildcard_models}\n")
    for model in expected_models:
        if model not in wildcard_models:
            print(f"Missing expected model: {model}")

    assert all(model in wildcard_models for model in expected_models)
47 changes: 47 additions & 0 deletions tests/test_models.py
@@ -47,6 +47,7 @@ async def get_models(session, key):

        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")
        return await response.json()


@pytest.mark.asyncio
@@ -112,6 +113,24 @@ async def get_model_info(session, key, litellm_model_id=None):
        return await response.json()


async def get_model_group_info(session, key):
    url = "http://0.0.0.0:4000/model_group/info"
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    async with session.get(url, headers=headers) as response:
        status = response.status
        response_text = await response.text()
        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")
        return await response.json()


async def chat_completion(session, key, model="azure-gpt-3.5"):
    url = "http://0.0.0.0:4000/chat/completions"
    headers = {
@@ -394,3 +413,31 @@ async def test_add_model_run_health():

        # cleanup
        await delete_model(session=session, model_id=model_id)


@pytest.mark.asyncio
async def test_model_group_info_e2e():
    """
    Test /model_group/info endpoint
    """
    async with aiohttp.ClientSession() as session:
        models = await get_models(session=session, key="sk-1234")
        print(models)

        expected_models = [
            "anthropic/claude-3-5-haiku-20241022",
            "anthropic/claude-3-opus-20240229",
        ]

        model_group_info = await get_model_group_info(session=session, key="sk-1234")
        print(model_group_info)

        has_anthropic_claude_3_5_haiku = False
        has_anthropic_claude_3_opus = False
        for model in model_group_info["data"]:
            if model["model_group"] == "anthropic/claude-3-5-haiku-20241022":
                has_anthropic_claude_3_5_haiku = True
            if model["model_group"] == "anthropic/claude-3-opus-20240229":
                has_anthropic_claude_3_opus = True

        assert has_anthropic_claude_3_5_haiku and has_anthropic_claude_3_opus