Skip to content

Commit

Permalink
Internal User Endpoint - vulnerability fix + response type fix (#8228)
Browse files Browse the repository at this point in the history
* fix(key_management_endpoints.py): fix vulnerability where a user could update another user's keys

Resolves #8031

* test(key_management_endpoints.py): return consistent 403 forbidden error when modifying key that doesn't belong to user

* fix(internal_user_endpoints.py): return model max budget in internal user create response

Fixes #7047

* test: fix test

* test: update test to handle gemini token counter change

* fix(factory.py): fix bedrock http:// handling

* docs: fix typo in lm_studio.md (#8222)

* test: fix testing

* test: fix test

---------

Co-authored-by: foreign-sub <51928805+foreign-sub@users.noreply.github.com>
  • Loading branch information
krrishdholakia and foreign-sub authored Feb 4, 2025
1 parent f6bd48a commit df93deb
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 28 deletions.
1 change: 1 addition & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ class NewUserResponse(GenerateKeyResponse):
] = None
teams: Optional[list] = None
user_alias: Optional[str] = None
model_max_budget: Optional[dict] = None


class UpdateUserRequest(GenerateRequestBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ async def new_user(
data_json = data.json() # type: ignore
data_json = _update_internal_new_user_params(data_json, data)
response = await generate_key_helper_fn(request_type="user", **data_json)

# Admin UI Logic
# Add User to Team and Organization
# if team_id passed add this user to the team
Expand Down Expand Up @@ -220,6 +219,7 @@ async def new_user(
tpm_limit=response.get("tpm_limit", None),
rpm_limit=response.get("rpm_limit", None),
budget_duration=response.get("budget_duration", None),
model_max_budget=response.get("model_max_budget", None),
)


Expand Down
75 changes: 55 additions & 20 deletions litellm/proxy/management_endpoints/key_management_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _get_user_in_team(
return None


def _is_allowed_to_create_key(
def _is_allowed_to_make_key_request(
user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], team_id: Optional[str]
) -> bool:
"""
Expand Down Expand Up @@ -266,6 +266,40 @@ def key_generation_check(
)


def common_key_access_checks(
user_api_key_dict: UserAPIKeyAuth,
data: Union[GenerateKeyRequest, UpdateKeyRequest],
llm_router: Optional[Router],
premium_user: bool,
) -> Literal[True]:
"""
Check if user is allowed to make a key request, for this key
"""
try:
_is_allowed_to_make_key_request(
user_api_key_dict=user_api_key_dict,
user_id=data.user_id,
team_id=data.team_id,
)
except AssertionError as e:
raise HTTPException(
status_code=403,
detail=str(e),
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=str(e),
)

_check_model_access_group(
models=data.models,
llm_router=llm_router,
premium_user=premium_user,
)
return True


router = APIRouter()


Expand Down Expand Up @@ -381,25 +415,9 @@ async def generate_key_fn( # noqa: PLR0915
data=data,
)

try:
_is_allowed_to_create_key(
user_api_key_dict=user_api_key_dict,
user_id=data.user_id,
team_id=data.team_id,
)
except AssertionError as e:
raise HTTPException(
status_code=403,
detail=str(e),
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=str(e),
)

_check_model_access_group(
models=data.models,
common_key_access_checks(
user_api_key_dict=user_api_key_dict,
data=data,
llm_router=llm_router,
premium_user=premium_user,
)
Expand Down Expand Up @@ -684,6 +702,8 @@ async def update_key_fn(
```
"""
from litellm.proxy.proxy_server import (
llm_router,
premium_user,
prisma_client,
proxy_logging_obj,
user_api_key_cache,
Expand All @@ -692,10 +712,18 @@ async def update_key_fn(
try:
data_json: dict = data.model_dump(exclude_unset=True, exclude_none=True)
key = data_json.pop("key")

# get the row from db
if prisma_client is None:
raise Exception("Not connected to DB!")

common_key_access_checks(
user_api_key_dict=user_api_key_dict,
data=data,
llm_router=llm_router,
premium_user=premium_user,
)

existing_key_row = await prisma_client.get_data(
token=data.key, table_name="key", query_type="find_unique"
)
Expand Down Expand Up @@ -1412,6 +1440,13 @@ async def _delete_key(key: LiteLLM_VerificationToken):
):
await prisma_client.delete_data(tokens=[key.token])
deleted_tokens.append(key.token)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"error": "You are not authorized to delete this key"
},
)

tasks.append(_delete_key(key))
await asyncio.gather(*tasks)
Expand Down
10 changes: 9 additions & 1 deletion tests/proxy_admin_ui_tests/test_key_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,15 @@ async def test_key_update_with_model_specific_params(prisma_client):
"litellm_budget_table": None,
"token": token_hash,
}
await update_key_fn(request=request, data=UpdateKeyRequest(**args))
await update_key_fn(
request=request,
data=UpdateKeyRequest(**args),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)


@pytest.mark.asyncio
Expand Down
32 changes: 30 additions & 2 deletions tests/proxy_unit_tests/test_key_generate_prisma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,11 @@ async def test():
budget_duration="1mo",
max_budget=100,
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)

print("response1=", response1)
Expand All @@ -1322,6 +1327,11 @@ async def test():
response2 = await update_key_fn(
request=Request,
data=UpdateKeyRequest(key=generated_key, team_id=_team_2),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print("response2=", response2)

Expand Down Expand Up @@ -2956,7 +2966,11 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
_request = Request(scope={"type": "http"})
_request._url = URL(url="/update/key")

await update_key_fn(data=request, request=_request)
await update_key_fn(
data=request,
request=_request,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
result = await info_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
Expand Down Expand Up @@ -3017,7 +3031,11 @@ async def test_generate_key_with_guardrails(prisma_client):
_request = Request(scope={"type": "http"})
_request._url = URL(url="/update/key")

await update_key_fn(data=request, request=_request)
await update_key_fn(
data=request,
request=_request,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
result = await info_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
Expand Down Expand Up @@ -3710,6 +3728,11 @@ async def test_key_alias_uniqueness(prisma_client):
await update_key_fn(
data=UpdateKeyRequest(key=key3.key, key_alias=unique_alias),
request=Request(scope={"type": "http"}),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
pytest.fail("Should not be able to update a key to use an existing alias")
except Exception as e:
Expand All @@ -3719,6 +3742,11 @@ async def test_key_alias_uniqueness(prisma_client):
updated_key = await update_key_fn(
data=UpdateKeyRequest(key=key1.key, key_alias=unique_alias),
request=Request(scope={"type": "http"}),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
assert updated_key is not None

Expand Down
9 changes: 5 additions & 4 deletions tests/proxy_unit_tests/test_proxy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,14 +1216,14 @@ def test_litellm_verification_token_view_response_with_budget_table(
)


def test_is_allowed_to_create_key():
def test_is_allowed_to_make_key_request():
from litellm.proxy._types import LitellmUserRoles
from litellm.proxy.management_endpoints.key_management_endpoints import (
_is_allowed_to_create_key,
_is_allowed_to_make_key_request,
)

assert (
_is_allowed_to_create_key(
_is_allowed_to_make_key_request(
user_api_key_dict=UserAPIKeyAuth(
user_id="test_user_id", user_role=LitellmUserRoles.PROXY_ADMIN
),
Expand All @@ -1234,7 +1234,7 @@ def test_is_allowed_to_create_key():
)

assert (
_is_allowed_to_create_key(
_is_allowed_to_make_key_request(
user_api_key_dict=UserAPIKeyAuth(
user_id="test_user_id",
user_role=LitellmUserRoles.INTERNAL_USER,
Expand Down Expand Up @@ -1553,6 +1553,7 @@ async def test_spend_logs_cleanup_after_error():
mock_client.spend_log_transactions == original_logs[100:]
), "Should remove processed logs even after error"


def test_provider_specific_header():
from litellm.proxy.litellm_pre_call_utils import (
add_provider_specific_headers_to_request,
Expand Down
Loading

0 comments on commit df93deb

Please sign in to comment.