From 1129268b8706cc469fe5fb56a36cb3f3d3cfd200 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 31 Jan 2025 18:07:38 -0800 Subject: [PATCH 1/3] build(schema.prisma): add new `sso_user_id` to LiteLLM_UserTable easier way to store sso id for existing user Allows existing user added to team, to login via SSO --- litellm/proxy/_new_secret_config.yaml | 13 +++++++++++- litellm/proxy/auth/auth_checks.py | 27 ++++++++++++++++++++++++- litellm/proxy/auth/user_api_key_auth.py | 9 +++++---- litellm/proxy/schema.prisma | 1 + 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index ddf14718c98a..78eb0e12e5c7 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -3,6 +3,10 @@ model_list: litellm_params: model: gpt-3.5-turbo rpm: 3 + - model_name: o3-mini + litellm_params: + model: o3-mini + rpm: 3 - model_name: anthropic-claude litellm_params: model: claude-3-5-haiku-20241022 @@ -18,4 +22,11 @@ model_list: litellm_settings: callbacks: ["langsmith"] - disable_no_log_param: true \ No newline at end of file + disable_no_log_param: true + +general_settings: + enable_jwt_auth: True + litellm_jwtauth: + user_id_jwt_field: "sub" + user_email_jwt_field: "email" + team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD \ No newline at end of file diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 8d0132709c18..71a944de8a2c 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -427,12 +427,14 @@ def get_role_based_models( @log_db_metrics async def get_user_object( - user_id: str, + user_id: Optional[str], prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, user_id_upsert: bool, parent_otel_span: Optional[Span] = None, proxy_logging_obj: Optional[ProxyLogging] = None, + sso_user_id: Optional[str] = None, + user_email: Optional[str] = None, ) -> Optional[LiteLLM_UserTable]: """ - Check if user id in proxy User Table @@ -465,6 +467,29 @@ async def get_user_object( response = await prisma_client.db.litellm_usertable.find_unique( where={"user_id": user_id}, include={"organization_memberships": True} ) + + if response is None and sso_user_id is not None: + response = await prisma_client.db.litellm_usertable.find_unique( + where={"sso_user_id": sso_user_id}, + include={"organization_memberships": True}, + ) + + if response is None and user_email is not None: + response = await prisma_client.db.litellm_usertable.find_first( + where={"user_email": user_email}, + include={"organization_memberships": True}, + ) + + if ( + response is not None and sso_user_id is not None + ): # update sso_user_id + asyncio.create_task( + prisma_client.db.litellm_usertable.update( + where={"user_id": response.user_id}, + data={"sso_user_id": sso_user_id}, + ) + ) + else: response = None diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 7d499af5b266..1bd01c42b4ce 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -340,7 +340,9 @@ async def _jwt_auth_user_api_key_auth_builder( # [OPTIONAL] allowed user email domains valid_user_email: Optional[bool] = None - user_email: Optional[str] = None + user_email: Optional[str] = jwt_handler.get_user_email( + token=jwt_valid_token, default_value=None + ) if jwt_handler.is_enforced_email_domain(): """ if 'allowed_email_subdomains' is set, @@ -348,9 +350,6 @@ async def _jwt_auth_user_api_key_auth_builder( - checks if token contains 'email' field - checks if 'email' is from an allowed domain """ - user_email = jwt_handler.get_user_email( - token=jwt_valid_token, default_value=None - ) if user_email is None: valid_user_email = False else: @@ -449,6 +448,8 @@ async def _jwt_auth_user_api_key_auth_builder( ), parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, + user_email=user_email, + sso_user_id=user_id, ) # [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable` end_user_object = None diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 2332e9c6aada..df337c94d0e3 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -103,6 +103,7 @@ model LiteLLM_UserTable { user_id String @id user_alias String? team_id String? + sso_user_id String? @unique organization_id String? password String? teams String[] @default([]) From efaa35b5b22a1ec96dc68f4738cef57338f7a228 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 31 Jan 2025 21:43:28 -0800 Subject: [PATCH 2/3] test(test_auth_checks.py): add unit testing for fuzzy user object get --- litellm/proxy/_types.py | 1 + litellm/proxy/auth/auth_checks.py | 67 +++++++++++------ tests/proxy_unit_tests/test_auth_checks.py | 85 ++++++++++++++++++++++ 3 files changed, 132 insertions(+), 21 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index bf13d178d461..2713977878bf 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1576,6 +1576,7 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase): user_role: Optional[str] = None organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None teams: List[str] = [] + sso_user_id: Optional[str] = None @model_validator(mode="before") @classmethod diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 71a944de8a2c..b28ca4cb2d92 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -28,6 +28,7 @@ CallInfo, LiteLLM_EndUserTable, LiteLLM_JWTAuth, + LiteLLM_OrganizationMembershipTable, LiteLLM_OrganizationTable, LiteLLM_TeamTable, LiteLLM_TeamTableCachedObj, @@ -425,6 +426,45 @@ def get_role_based_models( return None +async def _get_fuzzy_user_object( + prisma_client: PrismaClient, + sso_user_id: Optional[str] = None, + user_email: Optional[str] = None, +) -> Optional[LiteLLM_UserTable]: + """ + Checks if sso user is in db. + + Called when user id match is not found in db. + + - Check if sso_user_id is user_id in db + - Check if sso_user_id is sso_user_id in db + - Check if user_email is user_email in db + - If not, create new user with user_email and sso_user_id and user_id = sso_user_id + """ + response = None + if sso_user_id is not None: + response = await prisma_client.db.litellm_usertable.find_unique( + where={"sso_user_id": sso_user_id}, + include={"organization_memberships": True}, + ) + + if response is None and user_email is not None: + response = await prisma_client.db.litellm_usertable.find_first( + where={"user_email": user_email}, + include={"organization_memberships": True}, + ) + + if response is not None and sso_user_id is not None: # update sso_user_id + asyncio.create_task( # background task to update user with sso id + prisma_client.db.litellm_usertable.update( + where={"user_id": response.user_id}, + data={"sso_user_id": sso_user_id}, + ) + ) + + return response + + @log_db_metrics async def get_user_object( user_id: Optional[str], @@ -468,28 +508,13 @@ async def get_user_object( where={"user_id": user_id}, include={"organization_memberships": True} ) - if response is None and sso_user_id is not None: - response = await prisma_client.db.litellm_usertable.find_unique( - where={"sso_user_id": sso_user_id}, - include={"organization_memberships": True}, + if response is None: + response = await _get_fuzzy_user_object( + prisma_client=prisma_client, + sso_user_id=sso_user_id, + user_email=user_email, ) - if response is None and user_email is not None: - response = await prisma_client.db.litellm_usertable.find_first( - where={"user_email": user_email}, - include={"organization_memberships": True}, - ) - - if ( - response is not None and sso_user_id is not None - ): # update sso_user_id - asyncio.create_task( - prisma_client.db.litellm_usertable.update( - where={"user_id": response.user_id}, - data={"sso_user_id": sso_user_id}, - ) - ) - else: response = None @@ -508,7 +533,7 @@ async def get_user_object( ): # dump each organization membership to type LiteLLM_OrganizationMembershipTable _dumped_memberships = [ - membership.model_dump() + LiteLLM_OrganizationMembershipTable(**membership.model_dump()) for membership in response.organization_memberships if membership is not None ] diff --git a/tests/proxy_unit_tests/test_auth_checks.py b/tests/proxy_unit_tests/test_auth_checks.py index 04af3d6e299a..ad79328ade9b 100644 --- a/tests/proxy_unit_tests/test_auth_checks.py +++ b/tests/proxy_unit_tests/test_auth_checks.py @@ -548,3 +548,88 @@ async def test_can_user_call_model(): args["model"] = "gpt-3.5-turbo" await can_user_call_model(**args) + + +@pytest.mark.asyncio +async def test_get_fuzzy_user_object(): + from litellm.proxy.auth.auth_checks import _get_fuzzy_user_object + from litellm.proxy.utils import PrismaClient + from unittest.mock import AsyncMock, MagicMock + + # Setup mock Prisma client + mock_prisma = MagicMock() + mock_prisma.db = MagicMock() + mock_prisma.db.litellm_usertable = MagicMock() + + # Mock user data + test_user = LiteLLM_UserTable( + user_id="test_123", + sso_user_id="sso_123", + user_email="test@example.com", + organization_memberships=[], + max_budget=None, + ) + + # Test 1: Find user by SSO ID + mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user) + result = await _get_fuzzy_user_object( + prisma_client=mock_prisma, sso_user_id="sso_123", user_email="test@example.com" + ) + assert result == test_user + mock_prisma.db.litellm_usertable.find_unique.assert_called_with( + where={"sso_user_id": "sso_123"}, include={"organization_memberships": True} + ) + + # Test 2: SSO ID not found, find by email + mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None) + mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user) + mock_prisma.db.litellm_usertable.update = AsyncMock() + + result = await _get_fuzzy_user_object( + prisma_client=mock_prisma, + sso_user_id="new_sso_456", + user_email="test@example.com", + ) + assert result == test_user + mock_prisma.db.litellm_usertable.find_first.assert_called_with( + where={"user_email": "test@example.com"}, + include={"organization_memberships": True}, + ) + + # Test 3: Verify background SSO update task when user found by email + await asyncio.sleep(0.1) # Allow time for background task + mock_prisma.db.litellm_usertable.update.assert_called_with( + where={"user_id": "test_123"}, data={"sso_user_id": "new_sso_456"} + ) + + # Test 4: User not found by either method + mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None) + mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=None) + + result = await _get_fuzzy_user_object( + prisma_client=mock_prisma, + sso_user_id="unknown_sso", + user_email="unknown@example.com", + ) + assert result is None + + # Test 5: Only email provided (no SSO ID) + mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user) + result = await _get_fuzzy_user_object( + prisma_client=mock_prisma, user_email="test@example.com" + ) + assert result == test_user + mock_prisma.db.litellm_usertable.find_first.assert_called_with( + where={"user_email": "test@example.com"}, + include={"organization_memberships": True}, + ) + + # Test 6: Only SSO ID provided (no email) + mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user) + result = await _get_fuzzy_user_object( + prisma_client=mock_prisma, sso_user_id="sso_123" + ) + assert result == test_user + mock_prisma.db.litellm_usertable.find_unique.assert_called_with( + where={"sso_user_id": "sso_123"}, include={"organization_memberships": True} + ) From 5a3a5f9b5757ca3c4a9ec18c0adc411bbae68b5a Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 31 Jan 2025 23:02:12 -0800 Subject: [PATCH 3/3] fix(handle_jwt.py): fix merge conflicts --- litellm/proxy/auth/handle_jwt.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 09c1881f7121..2c876c856705 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -621,6 +621,7 @@ async def get_user_info( @staticmethod async def get_objects( user_id: Optional[str], + user_email: Optional[str], org_id: Optional[str], end_user_id: Optional[str], valid_user_email: Optional[bool], @@ -661,6 +662,8 @@ async def get_objects( ), parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, + user_email=user_email, + sso_user_id=user_id, ) if user_id else None @@ -704,7 +707,7 @@ async def auth_builder( # Get basic user info scopes = jwt_handler.get_scopes(token=jwt_valid_token) - user_id, _, valid_user_email = await JWTAuthManager.get_user_info( + user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info( jwt_handler, jwt_valid_token ) @@ -748,6 +751,7 @@ async def auth_builder( # Get other objects user_object, org_object, end_user_object = await JWTAuthManager.get_objects( user_id=user_id, + user_email=user_email, org_id=org_id, end_user_id=end_user_id, valid_user_email=valid_user_email,