From 97b8de17abe57f4489517f48cd60549c008ca17c Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Sun, 2 Feb 2025 23:17:50 -0800
Subject: [PATCH] LiteLLM Minor Fixes & Improvements (01/16/2025) - p2 (#7828)

* fix(vertex_ai/gemini/transformation.py): handle 'http://' image urls

* test: add base test for `http:` url's

* fix(factory.py/get_image_details): follow redirects

allows http calls to work

* fix(codestral/): fix stream chunk parsing on last chunk of stream

* Azure ad token provider (#6917)

* Update azure.py

Added optional parameter azure ad token provider

* Added parameter to main.py

* Found token provider arg location

* Fixed embeddings

* Fixed ad token provider

---------

Co-authored-by: Krish Dholakia

* fix: fix linting errors

* fix(main.py): leave out o1 route for azure ad token provider, for now

get v0 out for sync azure gpt route to begin with

* test: skip http:// test for fireworks ai

model does not support it

* refactor: cleanup dead code

* fix: revert http:// url passthrough for gemini

google ai studio raises errors

* test: fix test

---------

Co-authored-by: bahtman
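Usage sketch for the new parameter (deployment name, endpoint, and
api_version below are placeholders, and the snippet assumes the
`azure_ad_token_provider` kwarg is forwarded into litellm_params as
wired up in this patch):

    from azure.identity import DefaultAzureCredential, get_bearer_token_provider

    import litellm

    # A zero-argument callable that returns a fresh Azure AD bearer token.
    token_provider = get_bearer_token_provider(
        DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
    )

    response = litellm.completion(
        model="azure/my-gpt-4o-deployment",  # placeholder deployment
        api_base="https://my-resource.openai.azure.com",  # placeholder endpoint
        api_version="2024-06-01",
        messages=[{"role": "user", "content": "Hello"}],
        # consulted when neither api_key nor a static azure_ad_token is set
        azure_ad_token_provider=token_provider,
    )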
---
 litellm/llms/azure/azure.py                   | 47 +++++++++++++++++--
 litellm/llms/azure/completion/handler.py      |  2 +
 .../codestral/completion/transformation.py    | 12 ++++-
 litellm/main.py                               | 17 +++++++
 litellm/proxy/_new_secret_config.yaml         | 10 +---
 tests/llm_translation/base_llm_unit_tests.py  | 15 +++++-
 tests/llm_translation/test_vertex.py          |  5 +-
 .../test_anthropic_prompt_caching.py          | 25 ++++++----
 .../test_stream_chunk_builder.py              |  1 -
 9 files changed, 107 insertions(+), 27 deletions(-)

diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py
index f771532133c1..6c578e4d8ed3 100644
--- a/litellm/llms/azure/azure.py
+++ b/litellm/llms/azure/azure.py
@@ -2,7 +2,7 @@
 import json
 import os
 import time
-from typing import Any, Callable, List, Literal, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
 import httpx  # type: ignore
 from openai import AsyncAzureOpenAI, AzureOpenAI
@@ -217,7 +217,7 @@ class AzureChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def validate_environment(self, api_key, azure_ad_token):
+    def validate_environment(self, api_key, azure_ad_token, azure_ad_token_provider):
         headers = {
             "content-type": "application/json",
         }
@@ -227,6 +227,10 @@ def validate_environment(self, api_key, azure_ad_token):
             if azure_ad_token.startswith("oidc/"):
                 azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             headers["Authorization"] = f"Bearer {azure_ad_token}"
+        elif azure_ad_token_provider is not None:
+            azure_ad_token = azure_ad_token_provider()
+            headers["Authorization"] = f"Bearer {azure_ad_token}"
+
         return headers
 
     def _get_sync_azure_client(
@@ -235,6 +239,7 @@ def _get_sync_azure_client(
         self,
         api_version: Optional[str],
         api_base: Optional[str],
         api_key: Optional[str],
         azure_ad_token: Optional[str],
+        azure_ad_token_provider: Optional[Callable],
         model: str,
         max_retries: int,
         timeout: Union[float, httpx.Timeout],
@@ -242,7 +247,7 @@ def _get_sync_azure_client(
         client_type: Literal["sync", "async"],
     ):
         # init AzureOpenAI Client
-        azure_client_params = {
+        azure_client_params: Dict[str, Any] = {
             "api_version": api_version,
             "azure_endpoint": api_base,
             "azure_deployment": model,
@@ -259,6 +264,8 @@ def _get_sync_azure_client(
             if azure_ad_token.startswith("oidc/"):
                 azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
+        elif azure_ad_token_provider is not None:
+            azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
         if client is None:
             if client_type == "sync":
                 azure_client = AzureOpenAI(**azure_client_params)  # type: ignore
@@ -326,6 +333,7 @@ def completion(  # noqa: PLR0915
         api_version: str,
         api_type: str,
         azure_ad_token: str,
+        azure_ad_token_provider: Callable,
         dynamic_params: bool,
         print_verbose: Callable,
         timeout: Union[float, httpx.Timeout],
@@ -373,6 +381,10 @@ def completion(  # noqa: PLR0915
                     )
                 azure_client_params["azure_ad_token"] = azure_ad_token
+            elif azure_ad_token_provider is not None:
+                azure_client_params["azure_ad_token_provider"] = (
+                    azure_ad_token_provider
+                )
 
             if acompletion is True:
                 client = AsyncAzureOpenAI(**azure_client_params)
@@ -400,6 +412,7 @@ def completion(  # noqa: PLR0915
                     api_key=api_key,
                     api_version=api_version,
                     azure_ad_token=azure_ad_token,
+                    azure_ad_token_provider=azure_ad_token_provider,
                     timeout=timeout,
                     client=client,
                 )
@@ -412,6 +425,7 @@ def completion(  # noqa: PLR0915
                     api_version=api_version,
                     model=model,
                     azure_ad_token=azure_ad_token,
+                    azure_ad_token_provider=azure_ad_token_provider,
                     dynamic_params=dynamic_params,
                     timeout=timeout,
                     client=client,
@@ -428,6 +442,7 @@ def completion(  # noqa: PLR0915
                     api_key=api_key,
                     api_version=api_version,
                     azure_ad_token=azure_ad_token,
+                    azure_ad_token_provider=azure_ad_token_provider,
                     timeout=timeout,
                     client=client,
                 )
@@ -468,6 +483,10 @@ def completion(  # noqa: PLR0915
                 if azure_ad_token.startswith("oidc/"):
                     azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                 azure_client_params["azure_ad_token"] = azure_ad_token
+            elif azure_ad_token_provider is not None:
+                azure_client_params["azure_ad_token_provider"] = (
+                    azure_ad_token_provider
+                )
 
             if (
                 client is None
@@ -535,6 +554,7 @@ async def acompletion(
         model_response: ModelResponse,
         logging_obj: LiteLLMLoggingObj,
         azure_ad_token: Optional[str] = None,
+        azure_ad_token_provider: Optional[Callable] = None,
         convert_tool_call_to_json_mode: Optional[bool] = None,
         client=None,  # this is the AsyncAzureOpenAI
     ):
@@ -564,6 +584,8 @@ async def acompletion(
             if azure_ad_token.startswith("oidc/"):
                 azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
+        elif azure_ad_token_provider is not None:
+            azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
 
         # setting Azure client
         if client is None or dynamic_params:
@@ -650,6 +672,7 @@ def streaming(
         model: str,
         timeout: Any,
         azure_ad_token: Optional[str] = None,
+        azure_ad_token_provider: Optional[Callable] = None,
         client=None,
     ):
         max_retries = data.pop("max_retries", 2)
@@ -675,6 +698,8 @@ def streaming(
             if azure_ad_token.startswith("oidc/"):
                 azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
+        elif azure_ad_token_provider is not None:
+            azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
 
         if client is None or dynamic_params:
             azure_client = AzureOpenAI(**azure_client_params)
@@ -718,6 +743,7 @@ async def async_streaming(
         model: str,
         timeout: Any,
         azure_ad_token: Optional[str] = None,
+        azure_ad_token_provider: Optional[Callable] = None,
         client=None,
     ):
         try:
@@ -739,6 +765,8 @@ async def async_streaming(
             if azure_ad_token.startswith("oidc/"):
                 azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
+        elif azure_ad_token_provider is not None:
+            azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
 
         if client is None or dynamic_params:
             azure_client = AsyncAzureOpenAI(**azure_client_params)
         else:
@@ -844,6 +872,7 @@ def embedding(
         optional_params: dict,
         api_key: Optional[str] = None,
         azure_ad_token: Optional[str] = None,
+        azure_ad_token_provider: Optional[Callable] = None,
         max_retries: Optional[int] = None,
         client=None,
         aembedding=None,
@@ -883,6 +912,8 @@ def embedding(
             if azure_ad_token.startswith("oidc/"):
                 azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
+        elif azure_ad_token_provider is not None:
+            azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
 
         ## LOGGING
         logging_obj.pre_call(
@@ -1240,6 +1271,7 @@ def image_generation(
         api_version: Optional[str] = None,
         model_response: Optional[ImageResponse] = None,
         azure_ad_token: Optional[str] = None,
+        azure_ad_token_provider: Optional[Callable] = None,
         client=None,
         aimg_generation=None,
     ) -> ImageResponse:
@@ -1266,7 +1298,7 @@ def image_generation(
             )
 
         # init AzureOpenAI Client
-        azure_client_params = {
+        azure_client_params: Dict[str, Any] = {
             "api_version": api_version,
             "azure_endpoint": api_base,
             "azure_deployment": model,
@@ -1282,6 +1314,8 @@ def image_generation(
             if azure_ad_token.startswith("oidc/"):
                 azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
+        elif azure_ad_token_provider is not None:
+            azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
 
         if aimg_generation is True:
             return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers)  # type: ignore
@@ -1342,6 +1376,7 @@ def audio_speech(
         max_retries: int,
         timeout: Union[float, httpx.Timeout],
         azure_ad_token: Optional[str] = None,
+        azure_ad_token_provider: Optional[Callable] = None,
         aspeech: Optional[bool] = None,
         client=None,
     ) -> HttpxBinaryResponseContent:
@@ -1358,6 +1393,7 @@ def audio_speech(
                 api_base=api_base,
                 api_version=api_version,
                 azure_ad_token=azure_ad_token,
+                azure_ad_token_provider=azure_ad_token_provider,
                 max_retries=max_retries,
                 timeout=timeout,
                 client=client,
@@ -1368,6 +1404,7 @@ def audio_speech(
             api_version=api_version,
             api_key=api_key,
             azure_ad_token=azure_ad_token,
+            azure_ad_token_provider=azure_ad_token_provider,
             model=model,
             max_retries=max_retries,
             timeout=timeout,
@@ -1393,6 +1430,7 @@ async def async_audio_speech(
         api_base: Optional[str],
         api_version: Optional[str],
         azure_ad_token: Optional[str],
+        azure_ad_token_provider: Optional[Callable],
         max_retries: int,
         timeout: Union[float, httpx.Timeout],
         client=None,
@@ -1403,6 +1441,7 @@ async def async_audio_speech(
             api_version=api_version,
             api_key=api_key,
             azure_ad_token=azure_ad_token,
+            azure_ad_token_provider=azure_ad_token_provider,
             model=model,
             max_retries=max_retries,
             timeout=timeout,
diff --git a/litellm/llms/azure/completion/handler.py b/litellm/llms/azure/completion/handler.py
index 42309bdd2359..31d634de652c 100644
--- a/litellm/llms/azure/completion/handler.py
+++ b/litellm/llms/azure/completion/handler.py
@@ -49,6 +49,7 @@ def completion(  # noqa: PLR0915
         api_version: str,
         api_type: str,
         azure_ad_token: str,
+        azure_ad_token_provider: Optional[Callable],
         print_verbose: Callable,
         timeout,
         logging_obj,
@@ -170,6 +171,7 @@ def completion(  # noqa: PLR0915
                     "http_client": litellm.client_session,
                     "max_retries": max_retries,
                     "timeout": timeout,
+                    "azure_ad_token_provider": azure_ad_token_provider,
                 }
                 azure_client_params = select_azure_base_url_or_endpoint(
                     azure_client_params=azure_client_params
diff --git a/litellm/llms/codestral/completion/transformation.py b/litellm/llms/codestral/completion/transformation.py
index 261744d88569..84551cd55309 100644
--- a/litellm/llms/codestral/completion/transformation.py
+++ b/litellm/llms/codestral/completion/transformation.py
@@ -5,6 +5,7 @@
 from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
 from litellm.types.llms.databricks import GenericStreamingChunk
 
+
 class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
     """
     Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
@@ -77,6 +78,7 @@ def map_openai_params(
         return optional_params
 
     def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
+
         text = ""
         is_finished = False
         finish_reason = None
@@ -90,7 +92,15 @@ def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
                 "is_finished": is_finished,
                 "finish_reason": finish_reason,
             }
-        chunk_data_dict = json.loads(chunk_data)
+        try:
+            chunk_data_dict = json.loads(chunk_data)
+        except json.JSONDecodeError:
+            return {
+                "text": "",
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
+
         original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
         _choices = chunk_data_dict.get("choices", []) or []
         _choice = _choices[0]
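With the try/except above, a non-JSON terminal chunk in a Codestral FIM
stream (e.g. a bare "[DONE]" sentinel) now yields an empty chunk instead
of raising. Rough sketch of the streaming call this protects (model
alias illustrative; requires a Codestral API key):

    import litellm

    # Stream a fill-in-the-middle completion; the non-JSON final chunk
    # is swallowed by the parser instead of surfacing a JSONDecodeError.
    response = litellm.text_completion(
        model="text-completion-codestral/codestral-latest",
        prompt="def fibonacci(n):",
        stream=True,
    )
    for chunk in response:
        print(chunk.choices[0].text or "", end="")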
diff --git a/litellm/main.py b/litellm/main.py
index 40dbf3b2ad6f..cc71d3133bd8 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1214,6 +1214,10 @@ def completion(  # type: ignore # noqa: PLR0915
                 "azure_ad_token", None
             ) or get_secret("AZURE_AD_TOKEN")
 
+            azure_ad_token_provider = litellm_params.get(
+                "azure_ad_token_provider", None
+            )
+
             headers = headers or litellm.headers
 
             if extra_headers is not None:
@@ -1269,6 +1273,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 api_type=api_type,
                 dynamic_params=dynamic_params,
                 azure_ad_token=azure_ad_token,
+                azure_ad_token_provider=azure_ad_token_provider,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=optional_params,
@@ -1314,6 +1319,10 @@ def completion(  # type: ignore # noqa: PLR0915
                 "azure_ad_token", None
             ) or get_secret("AZURE_AD_TOKEN")
 
+            azure_ad_token_provider = litellm_params.get(
+                "azure_ad_token_provider", None
+            )
+
             headers = headers or litellm.headers
 
             if extra_headers is not None:
@@ -1337,6 +1346,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 api_version=api_version,
                 api_type=api_type,
                 azure_ad_token=azure_ad_token,
+                azure_ad_token_provider=azure_ad_token_provider,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=optional_params,
@@ -3244,6 +3254,7 @@ def embedding(  # noqa: PLR0915
     cooldown_time = kwargs.get("cooldown_time", None)
     mock_response: Optional[List[float]] = kwargs.get("mock_response", None)  # type: ignore
     max_parallel_requests = kwargs.pop("max_parallel_requests", None)
+    azure_ad_token_provider = kwargs.pop("azure_ad_token_provider", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", None)
     proxy_server_request = kwargs.get("proxy_server_request", None)
@@ -3374,6 +3385,7 @@ def embedding(  # noqa: PLR0915
                 api_key=api_key,
                 api_version=api_version,
                 azure_ad_token=azure_ad_token,
+                azure_ad_token_provider=azure_ad_token_provider,
                 logging_obj=logging,
                 timeout=timeout,
                 model_response=EmbeddingResponse(),
@@ -4449,6 +4461,7 @@ def image_generation(  # noqa: PLR0915
     logger_fn = kwargs.get("logger_fn", None)
     mock_response: Optional[str] = kwargs.get("mock_response", None)  # type: ignore
     proxy_server_request = kwargs.get("proxy_server_request", None)
+    azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", {})
     litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
@@ -4562,6 +4575,8 @@ def image_generation(  # noqa: PLR0915
             timeout=timeout,
             api_key=api_key,
             api_base=api_base,
+            azure_ad_token=azure_ad_token,
+            azure_ad_token_provider=azure_ad_token_provider,
             logging_obj=litellm_logging_obj,
             optional_params=optional_params,
             model_response=model_response,
@@ -5251,6 +5266,7 @@ def speech(
         ) or get_secret(
             "AZURE_AD_TOKEN"
         )
+        azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
 
         if extra_headers:
             optional_params["extra_headers"] = extra_headers
@@ -5264,6 +5280,7 @@ def speech(
             api_base=api_base,
             api_version=api_version,
             azure_ad_token=azure_ad_token,
+            azure_ad_token_provider=azure_ad_token_provider,
             organization=organization,
             max_retries=max_retries,
             timeout=timeout,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 8bdba65c39ea..fe32440e4ed0 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -28,12 +28,4 @@ model_list:
 
 litellm_settings:
-  callbacks: ["langsmith"]
-  disable_no_log_param: true
-
-general_settings:
-  enable_jwt_auth: True
-  litellm_jwtauth:
-    user_id_jwt_field: "sub"
-    user_email_jwt_field: "email"
-    team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
\ No newline at end of file
+  callbacks: ["langsmith"]
\ No newline at end of file
diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py
index 2704834486d2..06acdb5df910 100644
--- a/tests/llm_translation/base_llm_unit_tests.py
+++ b/tests/llm_translation/base_llm_unit_tests.py
@@ -461,8 +461,15 @@ def test_tool_call_no_arguments(self, tool_call_no_arguments):
         pass
 
     @pytest.mark.parametrize("detail", [None, "low", "high"])
+    @pytest.mark.parametrize(
+        "image_url",
+        [
+            "http://img1.etsystatic.com/260/0/7813604/il_fullxfull.4226713999_q86e.jpg",
+            "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
+        ],
+    )
     @pytest.mark.flaky(retries=4, delay=1)
-    def test_image_url(self, detail):
+    def test_image_url(self, detail, image_url):
         litellm.set_verbose = True
         from litellm.utils import supports_vision
 
@@ -472,6 +479,10 @@ def test_image_url(self, detail):
         base_completion_call_args = self.get_base_completion_call_args()
         if not supports_vision(base_completion_call_args["model"], None):
             pytest.skip("Model does not support image input")
+        elif "http://" in image_url and "fireworks_ai" in base_completion_call_args.get(
+            "model"
+        ):
+            pytest.skip("Model does not support http:// input")
 
         messages = [
             {
@@ -481,7 +492,7 @@ def test_image_url(self, detail):
                     {
                         "type": "image_url",
                         "image_url": {
-                            "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+                            "url": image_url,
                         },
                     },
                 ],
diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py
index 15d2df715191..db867e520294 100644
--- a/tests/llm_translation/test_vertex.py
+++ b/tests/llm_translation/test_vertex.py
@@ -1289,10 +1289,11 @@ def test_process_gemini_image_http_url(
         http_url: Test HTTP URL
         mock_convert_to_anthropic: Mocked convert_to_anthropic_image_obj function
         mock_blob: Mocked BlobType instance
+
+    Vertex AI supports image urls. Ensure no network requests are made.
     """
-    # Arrange
     expected_image_data = "data:image/jpeg;base64,/9j/4AAQSkZJRg..."
     mock_convert_url_to_base64.return_value = expected_image_data
-    # Act
     result = _process_gemini_image(http_url)
+    # assert result["file_data"]["file_uri"] == http_url
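The parametrized `test_image_url` above drives the same request shape
with both URL schemes. A rough sketch of that call (placeholder model;
any vision-capable model applies):

    import litellm

    # http:// image URLs are fetched (now following redirects) and handed
    # to the provider as base64 data when direct fetching is unsupported.
    response = litellm.completion(
        model="gemini/gemini-1.5-flash",  # placeholder vision model
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "http://img1.etsystatic.com/260/0/7813604/il_fullxfull.4226713999_q86e.jpg"
                        },
                    },
                ],
            }
        ],
    )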
diff --git a/tests/local_testing/test_anthropic_prompt_caching.py b/tests/local_testing/test_anthropic_prompt_caching.py
index 6919b5518647..4c2b66879e21 100644
--- a/tests/local_testing/test_anthropic_prompt_caching.py
+++ b/tests/local_testing/test_anthropic_prompt_caching.py
@@ -205,20 +205,29 @@ def anthropic_messages():
     ]
 
 
-def test_anthropic_vertex_ai_prompt_caching(anthropic_messages):
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_anthropic_vertex_ai_prompt_caching(anthropic_messages, sync_mode):
     litellm._turn_on_debug()
-    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
 
     load_vertex_ai_credentials()
-    client = HTTPHandler()
+    client = HTTPHandler() if sync_mode else AsyncHTTPHandler()
 
     with patch.object(client, "post", return_value=MagicMock()) as mock_post:
         try:
-            response = completion(
-                model="vertex_ai/claude-3-5-sonnet-v2@20241022 ",
-                messages=anthropic_messages,
-                client=client,
-            )
+            if sync_mode:
+                response = completion(
+                    model="vertex_ai/claude-3-5-sonnet-v2@20241022 ",
+                    messages=anthropic_messages,
+                    client=client,
+                )
+            else:
+                response = await litellm.acompletion(
+                    model="vertex_ai/claude-3-5-sonnet-v2@20241022 ",
+                    messages=anthropic_messages,
+                    client=client,
+                )
         except Exception as e:
             print(f"Error: {e}")
diff --git a/tests/local_testing/test_stream_chunk_builder.py b/tests/local_testing/test_stream_chunk_builder.py
index 8e7cfcf9ed27..28b5a3badbc2 100644
--- a/tests/local_testing/test_stream_chunk_builder.py
+++ b/tests/local_testing/test_stream_chunk_builder.py
@@ -730,7 +730,6 @@ def test_stream_chunk_builder_openai_audio_output_usage():
         usage_dict == response_usage_dict
     ), f"\nExpected: {usage_dict}\nGot: {response_usage_dict}"
 
-
 def test_stream_chunk_builder_empty_initial_chunk():
     from litellm.litellm_core_utils.streaming_chunk_builder_utils import (
         ChunkProcessor,