Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Test] Proxy - Unit Test proxy key gen #1478

Merged
merged 11 commits into from
Jan 17, 2024
2 changes: 1 addition & 1 deletion litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ async def user_api_key_auth(
# Token exists but is expired.
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="expired user key",
detail=f"Authentication Error - Expired Key. Key Expiry time {expiry_time} and current time {current_time}",
)

# Token passed all checks
Expand Down
13 changes: 11 additions & 2 deletions litellm/proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException, status
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime


def print_verbose(print_statement):
Expand Down Expand Up @@ -375,13 +377,14 @@ async def get_data(
print_verbose(f"PrismaClient: response={response}")
if response is not None:
# for prisma we need to cast the expires time to str
response.expires = response.expires.isoformat()
if isinstance(response.expires, datetime):
response.expires = response.expires.isoformat()
return response
else:
# Token does not exist.
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid user key",
detail="Authentication Error: invalid user key - token does not exist",
)
elif user_id is not None:
response = await self.db.litellm_usertable.find_unique( # type: ignore
Expand Down Expand Up @@ -559,7 +562,13 @@ async def delete_data(self, tokens: List):
)
async def connect(self):
try:
verbose_proxy_logger.debug(
"PrismaClient: connect() called Attempting to Connect to DB"
)
if self.db.is_connected() == False:
verbose_proxy_logger.debug(
"PrismaClient: DB not connected, Attempting to Connect to DB"
)
await self.db.connect()
except Exception as e:
asyncio.create_task(
Expand Down
40 changes: 24 additions & 16 deletions litellm/tests/test_key_generate_dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,6 @@
from starlette.datastructures import URL


db_args = {
"ssl_verify": False,
"billing_mode": "PAY_PER_REQUEST",
"region_name": "us-west-2",
}
custom_db_client = DBClient(
custom_db_type="dynamo_db",
custom_db_args=db_args,
)

request_data = {
"model": "azure-gpt-3.5",
"messages": [
Expand All @@ -50,7 +40,25 @@
}


def test_generate_and_call_with_valid_key():
@pytest.fixture
def custom_db_client():
# Assuming DBClient is a class that needs to be instantiated
db_args = {
"ssl_verify": False,
"billing_mode": "PAY_PER_REQUEST",
"region_name": "us-west-2",
}
custom_db_client = DBClient(
custom_db_type="dynamo_db",
custom_db_args=db_args,
)
# Reset litellm.proxy.proxy_server.prisma_client to None
litellm.proxy.proxy_server.prisma_client = None

return custom_db_client


def test_generate_and_call_with_valid_key(custom_db_client):
# 1. Generate a Key, and use it to make a call
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
Expand All @@ -76,7 +84,7 @@ async def test():
pytest.fail(f"An exception occurred - {str(e)}")


def test_call_with_invalid_key():
def test_call_with_invalid_key(custom_db_client):
# 2. Make a call with invalid key, expect it to fail
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
Expand All @@ -101,7 +109,7 @@ async def test():
pass


def test_call_with_invalid_model():
def test_call_with_invalid_model(custom_db_client):
# 3. Make a call to a key with an invalid model - expect to fail
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
Expand Down Expand Up @@ -136,7 +144,7 @@ async def return_body():
pass


def test_call_with_valid_model():
def test_call_with_valid_model(custom_db_client):
# 4. Make a call to a key with a valid model - expect to pass
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
Expand Down Expand Up @@ -167,7 +175,7 @@ async def return_body():
pytest.fail(f"An exception occurred - {str(e)}")


def test_call_with_key_over_budget():
def test_call_with_key_over_budget(custom_db_client):
# 5. Make a call with a key over budget, expect to fail
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
Expand Down Expand Up @@ -233,7 +241,7 @@ async def test():
print(vars(e))


def test_call_with_key_over_budget_stream():
def test_call_with_key_over_budget_stream(custom_db_client):
# 6. Make a call with a key over budget, expect to fail
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
Expand Down
Loading