diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index bf13d178d461..2713977878bf 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1576,6 +1576,7 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase): user_role: Optional[str] = None organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None teams: List[str] = [] + sso_user_id: Optional[str] = None @model_validator(mode="before") @classmethod diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 8d0132709c18..b28ca4cb2d92 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -28,6 +28,7 @@ CallInfo, LiteLLM_EndUserTable, LiteLLM_JWTAuth, + LiteLLM_OrganizationMembershipTable, LiteLLM_OrganizationTable, LiteLLM_TeamTable, LiteLLM_TeamTableCachedObj, @@ -425,14 +426,55 @@ def get_role_based_models( return None +async def _get_fuzzy_user_object( + prisma_client: PrismaClient, + sso_user_id: Optional[str] = None, + user_email: Optional[str] = None, +) -> Optional[LiteLLM_UserTable]: + """ + Checks if sso user is in db. + + Called when user id match is not found in db. + + - Check if sso_user_id is user_id in db + - Check if sso_user_id is sso_user_id in db + - Check if user_email is user_email in db + - If not, create new user with user_email and sso_user_id and user_id = sso_user_id + """ + response = None + if sso_user_id is not None: + response = await prisma_client.db.litellm_usertable.find_unique( + where={"sso_user_id": sso_user_id}, + include={"organization_memberships": True}, + ) + + if response is None and user_email is not None: + response = await prisma_client.db.litellm_usertable.find_first( + where={"user_email": user_email}, + include={"organization_memberships": True}, + ) + + if response is not None and sso_user_id is not None: # update sso_user_id + asyncio.create_task( # background task to update user with sso id + prisma_client.db.litellm_usertable.update( + where={"user_id": response.user_id}, + data={"sso_user_id": sso_user_id}, + ) + ) + + return response + + @log_db_metrics async def get_user_object( - user_id: str, + user_id: Optional[str], prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, user_id_upsert: bool, parent_otel_span: Optional[Span] = None, proxy_logging_obj: Optional[ProxyLogging] = None, + sso_user_id: Optional[str] = None, + user_email: Optional[str] = None, ) -> Optional[LiteLLM_UserTable]: """ - Check if user id in proxy User Table @@ -465,6 +507,14 @@ async def get_user_object( response = await prisma_client.db.litellm_usertable.find_unique( where={"user_id": user_id}, include={"organization_memberships": True} ) + + if response is None: + response = await _get_fuzzy_user_object( + prisma_client=prisma_client, + sso_user_id=sso_user_id, + user_email=user_email, + ) + else: response = None @@ -483,7 +533,7 @@ async def get_user_object( ): # dump each organization membership to type LiteLLM_OrganizationMembershipTable _dumped_memberships = [ - membership.model_dump() + LiteLLM_OrganizationMembershipTable(**membership.model_dump()) for membership in response.organization_memberships if membership is not None ] diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 09c1881f7121..2c876c856705 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -621,6 +621,7 @@ async def get_user_info( @staticmethod async def get_objects( user_id: Optional[str], + user_email: Optional[str], org_id: Optional[str], end_user_id: Optional[str], valid_user_email: Optional[bool], @@ -661,6 +662,8 @@ async def get_objects( ), parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, + user_email=user_email, + sso_user_id=user_id, ) if user_id else None @@ -704,7 +707,7 @@ async def auth_builder( # Get basic user info scopes = jwt_handler.get_scopes(token=jwt_valid_token) - user_id, _, valid_user_email = await JWTAuthManager.get_user_info( + user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info( jwt_handler, jwt_valid_token ) @@ -748,6 +751,7 @@ async def auth_builder( # Get other objects user_object, org_object, end_user_object = await JWTAuthManager.get_objects( user_id=user_id, + user_email=user_email, org_id=org_id, end_user_id=end_user_id, valid_user_email=valid_user_email, diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 2332e9c6aada..df337c94d0e3 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -103,6 +103,7 @@ model LiteLLM_UserTable { user_id String @id user_alias String? team_id String? + sso_user_id String? @unique organization_id String? password String? teams String[] @default([]) diff --git a/tests/proxy_unit_tests/test_auth_checks.py b/tests/proxy_unit_tests/test_auth_checks.py index 04af3d6e299a..ad79328ade9b 100644 --- a/tests/proxy_unit_tests/test_auth_checks.py +++ b/tests/proxy_unit_tests/test_auth_checks.py @@ -548,3 +548,88 @@ async def test_can_user_call_model(): args["model"] = "gpt-3.5-turbo" await can_user_call_model(**args) + + +@pytest.mark.asyncio +async def test_get_fuzzy_user_object(): + from litellm.proxy.auth.auth_checks import _get_fuzzy_user_object + from litellm.proxy.utils import PrismaClient + from unittest.mock import AsyncMock, MagicMock + + # Setup mock Prisma client + mock_prisma = MagicMock() + mock_prisma.db = MagicMock() + mock_prisma.db.litellm_usertable = MagicMock() + + # Mock user data + test_user = LiteLLM_UserTable( + user_id="test_123", + sso_user_id="sso_123", + user_email="test@example.com", + organization_memberships=[], + max_budget=None, + ) + + # Test 1: Find user by SSO ID + mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user) + result = await _get_fuzzy_user_object( + prisma_client=mock_prisma, sso_user_id="sso_123", user_email="test@example.com" + ) + assert result == test_user + mock_prisma.db.litellm_usertable.find_unique.assert_called_with( + where={"sso_user_id": "sso_123"}, include={"organization_memberships": True} + ) + + # Test 2: SSO ID not found, find by email + mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None) + mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user) + mock_prisma.db.litellm_usertable.update = AsyncMock() + + result = await _get_fuzzy_user_object( + prisma_client=mock_prisma, + sso_user_id="new_sso_456", + user_email="test@example.com", + ) + assert result == test_user + mock_prisma.db.litellm_usertable.find_first.assert_called_with( + where={"user_email": "test@example.com"}, + include={"organization_memberships": True}, + ) + + # Test 3: Verify background SSO update task when user found by email + await asyncio.sleep(0.1) # Allow time for background task + mock_prisma.db.litellm_usertable.update.assert_called_with( + where={"user_id": "test_123"}, data={"sso_user_id": "new_sso_456"} + ) + + # Test 4: User not found by either method + mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None) + mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=None) + + result = await _get_fuzzy_user_object( + prisma_client=mock_prisma, + sso_user_id="unknown_sso", + user_email="unknown@example.com", + ) + assert result is None + + # Test 5: Only email provided (no SSO ID) + mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user) + result = await _get_fuzzy_user_object( + prisma_client=mock_prisma, user_email="test@example.com" + ) + assert result == test_user + mock_prisma.db.litellm_usertable.find_first.assert_called_with( + where={"user_email": "test@example.com"}, + include={"organization_memberships": True}, + ) + + # Test 6: Only SSO ID provided (no email) + mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user) + result = await _get_fuzzy_user_object( + prisma_client=mock_prisma, sso_user_id="sso_123" + ) + assert result == test_user + mock_prisma.db.litellm_usertable.find_unique.assert_called_with( + where={"sso_user_id": "sso_123"}, include={"organization_memberships": True} + )