Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

build(schema.prisma): add new sso_user_id to LiteLLM_UserTable #8167

Merged
merged 4 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1576,6 +1576,7 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
user_role: Optional[str] = None
organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None
teams: List[str] = []
sso_user_id: Optional[str] = None

@model_validator(mode="before")
@classmethod
Expand Down
54 changes: 52 additions & 2 deletions litellm/proxy/auth/auth_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CallInfo,
LiteLLM_EndUserTable,
LiteLLM_JWTAuth,
LiteLLM_OrganizationMembershipTable,
LiteLLM_OrganizationTable,
LiteLLM_TeamTable,
LiteLLM_TeamTableCachedObj,
Expand Down Expand Up @@ -425,14 +426,55 @@ def get_role_based_models(
return None


async def _get_fuzzy_user_object(
prisma_client: PrismaClient,
sso_user_id: Optional[str] = None,
user_email: Optional[str] = None,
) -> Optional[LiteLLM_UserTable]:
"""
Checks if sso user is in db.

Called when user id match is not found in db.

- Check if sso_user_id is user_id in db
- Check if sso_user_id is sso_user_id in db
- Check if user_email is user_email in db
- If not, create new user with user_email and sso_user_id and user_id = sso_user_id
"""
response = None
if sso_user_id is not None:
response = await prisma_client.db.litellm_usertable.find_unique(
where={"sso_user_id": sso_user_id},
include={"organization_memberships": True},
)

if response is None and user_email is not None:
response = await prisma_client.db.litellm_usertable.find_first(
where={"user_email": user_email},
include={"organization_memberships": True},
)

if response is not None and sso_user_id is not None: # update sso_user_id
asyncio.create_task( # background task to update user with sso id
prisma_client.db.litellm_usertable.update(
where={"user_id": response.user_id},
data={"sso_user_id": sso_user_id},
)
)

return response


@log_db_metrics
async def get_user_object(
user_id: str,
user_id: Optional[str],
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
user_id_upsert: bool,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
sso_user_id: Optional[str] = None,
user_email: Optional[str] = None,
) -> Optional[LiteLLM_UserTable]:
"""
- Check if user id in proxy User Table
Expand Down Expand Up @@ -465,6 +507,14 @@ async def get_user_object(
response = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}, include={"organization_memberships": True}
)

if response is None:
response = await _get_fuzzy_user_object(
prisma_client=prisma_client,
sso_user_id=sso_user_id,
user_email=user_email,
)

else:
response = None

Expand All @@ -483,7 +533,7 @@ async def get_user_object(
):
# dump each organization membership to type LiteLLM_OrganizationMembershipTable
_dumped_memberships = [
membership.model_dump()
LiteLLM_OrganizationMembershipTable(**membership.model_dump())
for membership in response.organization_memberships
if membership is not None
]
Expand Down
6 changes: 5 additions & 1 deletion litellm/proxy/auth/handle_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ async def get_user_info(
@staticmethod
async def get_objects(
user_id: Optional[str],
user_email: Optional[str],
org_id: Optional[str],
end_user_id: Optional[str],
valid_user_email: Optional[bool],
Expand Down Expand Up @@ -661,6 +662,8 @@ async def get_objects(
),
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
user_email=user_email,
sso_user_id=user_id,
)
if user_id
else None
Expand Down Expand Up @@ -704,7 +707,7 @@ async def auth_builder(

# Get basic user info
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
user_id, _, valid_user_email = await JWTAuthManager.get_user_info(
user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info(
jwt_handler, jwt_valid_token
)

Expand Down Expand Up @@ -748,6 +751,7 @@ async def auth_builder(
# Get other objects
user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
user_id=user_id,
user_email=user_email,
org_id=org_id,
end_user_id=end_user_id,
valid_user_email=valid_user_email,
Expand Down
1 change: 1 addition & 0 deletions litellm/proxy/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ model LiteLLM_UserTable {
user_id String @id
user_alias String?
team_id String?
sso_user_id String? @unique
organization_id String?
password String?
teams String[] @default([])
Expand Down
85 changes: 85 additions & 0 deletions tests/proxy_unit_tests/test_auth_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,3 +548,88 @@ async def test_can_user_call_model():

args["model"] = "gpt-3.5-turbo"
await can_user_call_model(**args)


@pytest.mark.asyncio
async def test_get_fuzzy_user_object():
from litellm.proxy.auth.auth_checks import _get_fuzzy_user_object
from litellm.proxy.utils import PrismaClient
from unittest.mock import AsyncMock, MagicMock

# Setup mock Prisma client
mock_prisma = MagicMock()
mock_prisma.db = MagicMock()
mock_prisma.db.litellm_usertable = MagicMock()

# Mock user data
test_user = LiteLLM_UserTable(
user_id="test_123",
sso_user_id="sso_123",
user_email="test@example.com",
organization_memberships=[],
max_budget=None,
)

# Test 1: Find user by SSO ID
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user)
result = await _get_fuzzy_user_object(
prisma_client=mock_prisma, sso_user_id="sso_123", user_email="test@example.com"
)
assert result == test_user
mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
)

# Test 2: SSO ID not found, find by email
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user)
mock_prisma.db.litellm_usertable.update = AsyncMock()

result = await _get_fuzzy_user_object(
prisma_client=mock_prisma,
sso_user_id="new_sso_456",
user_email="test@example.com",
)
assert result == test_user
mock_prisma.db.litellm_usertable.find_first.assert_called_with(
where={"user_email": "test@example.com"},
include={"organization_memberships": True},
)

# Test 3: Verify background SSO update task when user found by email
await asyncio.sleep(0.1) # Allow time for background task
mock_prisma.db.litellm_usertable.update.assert_called_with(
where={"user_id": "test_123"}, data={"sso_user_id": "new_sso_456"}
)

# Test 4: User not found by either method
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=None)

result = await _get_fuzzy_user_object(
prisma_client=mock_prisma,
sso_user_id="unknown_sso",
user_email="unknown@example.com",
)
assert result is None

# Test 5: Only email provided (no SSO ID)
mock_prisma.db.litellm_usertable.find_first = AsyncMock(return_value=test_user)
result = await _get_fuzzy_user_object(
prisma_client=mock_prisma, user_email="test@example.com"
)
assert result == test_user
mock_prisma.db.litellm_usertable.find_first.assert_called_with(
where={"user_email": "test@example.com"},
include={"organization_memberships": True},
)

# Test 6: Only SSO ID provided (no email)
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=test_user)
result = await _get_fuzzy_user_object(
prisma_client=mock_prisma, sso_user_id="sso_123"
)
assert result == test_user
mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
)