Skip to content

Commit

Permalink
LiteLLM Minor Fixes & Improvements (01/16/2025) - p2 (#7828)
Browse files Browse the repository at this point in the history
* fix(vertex_ai/gemini/transformation.py): handle 'http://' image urls

* test: add base test for `http:` urls

* fix(factory.py/get_image_details): follow redirects

allows http calls to work

* fix(codestral/): fix stream chunk parsing on last chunk of stream

* Azure ad token provider (#6917)

* Update azure.py

Added optional parameter azure ad token provider

* Added parameter to main.py

* Found token provider arg location

* Fixed embeddings

* Fixed ad token provider

---------

Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com>

* fix: fix linting errors

* fix(main.py): leave out o1 route for azure ad token provider, for now

get v0 out for sync azure gpt route to begin with

* test: skip http:// test for fireworks ai

the model does not support it

* refactor: cleanup dead code

* fix: revert http:// url passthrough for gemini

Google AI Studio raises errors

* test: fix test

---------

Co-authored-by: bahtman <anton@baht.dk>
  • Loading branch information
krrishdholakia and bahtman authored Feb 3, 2025
1 parent 10d3da7 commit 97b8de1
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 27 deletions.
47 changes: 43 additions & 4 deletions litellm/llms/azure/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import os
import time
from typing import Any, Callable, List, Literal, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Union

import httpx # type: ignore
from openai import AsyncAzureOpenAI, AzureOpenAI
Expand Down Expand Up @@ -217,7 +217,7 @@ class AzureChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()

def validate_environment(self, api_key, azure_ad_token):
def validate_environment(self, api_key, azure_ad_token, azure_ad_token_provider):
headers = {
"content-type": "application/json",
}
Expand All @@ -227,6 +227,10 @@ def validate_environment(self, api_key, azure_ad_token):
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
headers["Authorization"] = f"Bearer {azure_ad_token}"
elif azure_ad_token_provider is not None:
azure_ad_token = azure_ad_token_provider()
headers["Authorization"] = f"Bearer {azure_ad_token}"

return headers

def _get_sync_azure_client(
Expand All @@ -235,14 +239,15 @@ def _get_sync_azure_client(
api_base: Optional[str],
api_key: Optional[str],
azure_ad_token: Optional[str],
azure_ad_token_provider: Optional[Callable],
model: str,
max_retries: int,
timeout: Union[float, httpx.Timeout],
client: Optional[Any],
client_type: Literal["sync", "async"],
):
# init AzureOpenAI Client
azure_client_params = {
azure_client_params: Dict[str, Any] = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
Expand All @@ -259,6 +264,8 @@ def _get_sync_azure_client(
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None:
if client_type == "sync":
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
Expand Down Expand Up @@ -326,6 +333,7 @@ def completion( # noqa: PLR0915
api_version: str,
api_type: str,
azure_ad_token: str,
azure_ad_token_provider: Callable,
dynamic_params: bool,
print_verbose: Callable,
timeout: Union[float, httpx.Timeout],
Expand Down Expand Up @@ -373,6 +381,10 @@ def completion( # noqa: PLR0915
)

azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = (
azure_ad_token_provider
)

if acompletion is True:
client = AsyncAzureOpenAI(**azure_client_params)
Expand Down Expand Up @@ -400,6 +412,7 @@ def completion( # noqa: PLR0915
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
timeout=timeout,
client=client,
)
Expand All @@ -412,6 +425,7 @@ def completion( # noqa: PLR0915
api_version=api_version,
model=model,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
dynamic_params=dynamic_params,
timeout=timeout,
client=client,
Expand All @@ -428,6 +442,7 @@ def completion( # noqa: PLR0915
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
timeout=timeout,
client=client,
)
Expand Down Expand Up @@ -468,6 +483,10 @@ def completion( # noqa: PLR0915
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = (
azure_ad_token_provider
)

if (
client is None
Expand Down Expand Up @@ -535,6 +554,7 @@ async def acompletion(
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
convert_tool_call_to_json_mode: Optional[bool] = None,
client=None, # this is the AsyncAzureOpenAI
):
Expand Down Expand Up @@ -564,6 +584,8 @@ async def acompletion(
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider

# setting Azure client
if client is None or dynamic_params:
Expand Down Expand Up @@ -650,6 +672,7 @@ def streaming(
model: str,
timeout: Any,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None,
):
max_retries = data.pop("max_retries", 2)
Expand All @@ -675,6 +698,8 @@ def streaming(
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider

if client is None or dynamic_params:
azure_client = AzureOpenAI(**azure_client_params)
Expand Down Expand Up @@ -718,6 +743,7 @@ async def async_streaming(
model: str,
timeout: Any,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None,
):
try:
Expand All @@ -739,6 +765,8 @@ async def async_streaming(
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None or dynamic_params:
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
Expand Down Expand Up @@ -844,6 +872,7 @@ def embedding(
optional_params: dict,
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
max_retries: Optional[int] = None,
client=None,
aembedding=None,
Expand Down Expand Up @@ -883,6 +912,8 @@ def embedding(
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider

## LOGGING
logging_obj.pre_call(
Expand Down Expand Up @@ -1240,6 +1271,7 @@ def image_generation(
api_version: Optional[str] = None,
model_response: Optional[ImageResponse] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None,
aimg_generation=None,
) -> ImageResponse:
Expand All @@ -1266,7 +1298,7 @@ def image_generation(
)

# init AzureOpenAI Client
azure_client_params = {
azure_client_params: Dict[str, Any] = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
Expand All @@ -1282,6 +1314,8 @@ def image_generation(
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider

if aimg_generation is True:
return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore
Expand Down Expand Up @@ -1342,6 +1376,7 @@ def audio_speech(
max_retries: int,
timeout: Union[float, httpx.Timeout],
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
aspeech: Optional[bool] = None,
client=None,
) -> HttpxBinaryResponseContent:
Expand All @@ -1358,6 +1393,7 @@ def audio_speech(
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
max_retries=max_retries,
timeout=timeout,
client=client,
Expand All @@ -1368,6 +1404,7 @@ def audio_speech(
api_version=api_version,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model=model,
max_retries=max_retries,
timeout=timeout,
Expand All @@ -1393,6 +1430,7 @@ async def async_audio_speech(
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
azure_ad_token_provider: Optional[Callable],
max_retries: int,
timeout: Union[float, httpx.Timeout],
client=None,
Expand All @@ -1403,6 +1441,7 @@ async def async_audio_speech(
api_version=api_version,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model=model,
max_retries=max_retries,
timeout=timeout,
Expand Down
2 changes: 2 additions & 0 deletions litellm/llms/azure/completion/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def completion( # noqa: PLR0915
api_version: str,
api_type: str,
azure_ad_token: str,
azure_ad_token_provider: Optional[Callable],
print_verbose: Callable,
timeout,
logging_obj,
Expand Down Expand Up @@ -170,6 +171,7 @@ def completion( # noqa: PLR0915
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
"azure_ad_token_provider": azure_ad_token_provider,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
Expand Down
12 changes: 11 additions & 1 deletion litellm/llms/codestral/completion/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from litellm.types.llms.databricks import GenericStreamingChunk


class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
"""
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
Expand Down Expand Up @@ -77,6 +78,7 @@ def map_openai_params(
return optional_params

def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:

text = ""
is_finished = False
finish_reason = None
Expand All @@ -90,7 +92,15 @@ def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
"is_finished": is_finished,
"finish_reason": finish_reason,
}
chunk_data_dict = json.loads(chunk_data)
try:
chunk_data_dict = json.loads(chunk_data)
except json.JSONDecodeError:
return {
"text": "",
"is_finished": is_finished,
"finish_reason": finish_reason,
}

original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
_choices = chunk_data_dict.get("choices", []) or []
_choice = _choices[0]
Expand Down
17 changes: 17 additions & 0 deletions litellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,6 +1214,10 @@ def completion( # type: ignore # noqa: PLR0915
"azure_ad_token", None
) or get_secret("AZURE_AD_TOKEN")

azure_ad_token_provider = litellm_params.get(
"azure_ad_token_provider", None
)

headers = headers or litellm.headers

if extra_headers is not None:
Expand Down Expand Up @@ -1269,6 +1273,7 @@ def completion( # type: ignore # noqa: PLR0915
api_type=api_type,
dynamic_params=dynamic_params,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
Expand Down Expand Up @@ -1314,6 +1319,10 @@ def completion( # type: ignore # noqa: PLR0915
"azure_ad_token", None
) or get_secret("AZURE_AD_TOKEN")

azure_ad_token_provider = litellm_params.get(
"azure_ad_token_provider", None
)

headers = headers or litellm.headers

if extra_headers is not None:
Expand All @@ -1337,6 +1346,7 @@ def completion( # type: ignore # noqa: PLR0915
api_version=api_version,
api_type=api_type,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
Expand Down Expand Up @@ -3244,6 +3254,7 @@ def embedding( # noqa: PLR0915
cooldown_time = kwargs.get("cooldown_time", None)
mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
azure_ad_token_provider = kwargs.pop("azure_ad_token_provider", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
Expand Down Expand Up @@ -3374,6 +3385,7 @@ def embedding( # noqa: PLR0915
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
Expand Down Expand Up @@ -4449,6 +4461,7 @@ def image_generation( # noqa: PLR0915
logger_fn = kwargs.get("logger_fn", None)
mock_response: Optional[str] = kwargs.get("mock_response", None) # type: ignore
proxy_server_request = kwargs.get("proxy_server_request", None)
azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
Expand Down Expand Up @@ -4562,6 +4575,8 @@ def image_generation( # noqa: PLR0915
timeout=timeout,
api_key=api_key,
api_base=api_base,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
Expand Down Expand Up @@ -5251,6 +5266,7 @@ def speech(
) or get_secret(
"AZURE_AD_TOKEN"
)
azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)

if extra_headers:
optional_params["extra_headers"] = extra_headers
Expand All @@ -5264,6 +5280,7 @@ def speech(
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
max_retries=max_retries,
timeout=timeout,
Expand Down
10 changes: 1 addition & 9 deletions litellm/proxy/_new_secret_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,4 @@ model_list:


litellm_settings:
callbacks: ["langsmith"]
disable_no_log_param: true

general_settings:
enable_jwt_auth: True
litellm_jwtauth:
user_id_jwt_field: "sub"
user_email_jwt_field: "email"
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
callbacks: ["langsmith"]
Loading

0 comments on commit 97b8de1

Please sign in to comment.