fix(utils.py): fix vertex ai optional param handling
don't pass max retries to unsupported route

Fixes #8254
krrishdholakia committed Feb 12, 2025
1 parent 5e58ae0 commit 04d07b0
Showing 3 changed files with 55 additions and 33 deletions.
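To make the intent concrete, here is a minimal sketch of the behavior this commit establishes, mirroring the test_vertex_ai_ft_llama test added at the bottom of the diff. get_optional_params is the function patched in litellm/utils.py; the model id and parameters are taken from that test.

# Minimal sketch of the fixed behavior, mirroring test_vertex_ai_ft_llama below.
from litellm.utils import get_optional_params

# Fine-tuned Vertex AI Llama models are addressed by a numeric endpoint id, so
# they appear in no known vertex model list; before this commit they fell
# through param mapping and max_retries leaked to a route that rejects it.
optional_params = get_optional_params(
    model="1984786713414729728",
    custom_llm_provider="vertex_ai",
    frequency_penalty=0.5,
    max_retries=10,
)

assert optional_params["frequency_penalty"] == 0.5  # still mapped through
assert "max_retries" not in optional_params  # dropped for this route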
@@ -2,9 +2,10 @@
 from typing import Optional
 
 import litellm
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 
 
-class VertexAILlama3Config:
+class VertexAILlama3Config(OpenAIGPTConfig):
     """
     Reference:https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
@@ -46,8 +47,13 @@ def get_config(cls):
             and v is not None
         }
 
-    def get_supported_openai_params(self):
-        return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
+    def get_supported_openai_params(self, model: str):
+        supported_params = super().get_supported_openai_params(model=model)
+        try:
+            supported_params.remove("max_retries")
+        except (KeyError, ValueError):  # list.remove raises ValueError, not KeyError
+            pass
+        return supported_params
 
     def map_openai_params(
         self,
@@ -60,7 +66,7 @@ def map_openai_params(
             non_default_params["max_tokens"] = non_default_params.pop(
                 "max_completion_tokens"
            )
-        return litellm.OpenAIConfig().map_openai_params(
+        return super().map_openai_params(
             non_default_params=non_default_params,
             optional_params=optional_params,
             model=model,
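The override above is the inherit-then-filter pattern: take the parent's supported-parameter list and strip entries this route cannot accept. Below is a self-contained sketch of that pattern; BaseConfig and PartnerModelConfig are hypothetical stand-ins, not litellm classes. Note that list.remove raises ValueError rather than KeyError when the entry is absent, so a membership check (or catching ValueError) is the safe guard.

from typing import List


class BaseConfig:
    """Hypothetical stand-in for the generic OpenAI-style param list."""

    def get_supported_openai_params(self, model: str) -> List[str]:
        return ["temperature", "max_tokens", "frequency_penalty", "max_retries"]


class PartnerModelConfig(BaseConfig):
    """Inherit the generic list, then drop params this route can't accept."""

    def get_supported_openai_params(self, model: str) -> List[str]:
        supported_params = super().get_supported_openai_params(model=model)
        # list.remove raises ValueError when the entry is missing, so guard it.
        if "max_retries" in supported_params:
            supported_params.remove("max_retries")
        return supported_params


print(PartnerModelConfig().get_supported_openai_params(model="any-model"))
# ['temperature', 'max_tokens', 'frequency_penalty']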
63 changes: 34 additions & 29 deletions litellm/utils.py
@@ -3166,51 +3166,56 @@ def _check_valid_arg(supported_params: List[str]):
                 else False
             ),
         )
-    elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_llama3_models:
-        optional_params = litellm.VertexAILlama3Config().map_openai_params(
-            non_default_params=non_default_params,
-            optional_params=optional_params,
-            model=model,
-            drop_params=(
-                drop_params
-                if drop_params is not None and isinstance(drop_params, bool)
-                else False
-            ),
-        )
-    elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_mistral_models:
-        if "codestral" in model:
-            optional_params = litellm.CodestralTextCompletionConfig().map_openai_params(
-                model=model,
-                non_default_params=non_default_params,
-                optional_params=optional_params,
-                drop_params=(
-                    drop_params
-                    if drop_params is not None and isinstance(drop_params, bool)
-                    else False
-                ),
-            )
-        else:
-            optional_params = litellm.MistralConfig().map_openai_params(
-                model=model,
-                non_default_params=non_default_params,
-                optional_params=optional_params,
-                drop_params=(
-                    drop_params
-                    if drop_params is not None and isinstance(drop_params, bool)
-                    else False
-                ),
-            )
-    elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_ai_ai21_models:
-        optional_params = litellm.VertexAIAi21Config().map_openai_params(
-            non_default_params=non_default_params,
-            optional_params=optional_params,
-            model=model,
-            drop_params=(
-                drop_params
-                if drop_params is not None and isinstance(drop_params, bool)
-                else False
-            ),
-        )
+    elif custom_llm_provider == "vertex_ai":
+
+        if model in litellm.vertex_mistral_models:
+            if "codestral" in model:
+                optional_params = (
+                    litellm.CodestralTextCompletionConfig().map_openai_params(
+                        model=model,
+                        non_default_params=non_default_params,
+                        optional_params=optional_params,
+                        drop_params=(
+                            drop_params
+                            if drop_params is not None and isinstance(drop_params, bool)
+                            else False
+                        ),
+                    )
+                )
+            else:
+                optional_params = litellm.MistralConfig().map_openai_params(
+                    model=model,
+                    non_default_params=non_default_params,
+                    optional_params=optional_params,
+                    drop_params=(
+                        drop_params
+                        if drop_params is not None and isinstance(drop_params, bool)
+                        else False
+                    ),
+                )
+        elif model in litellm.vertex_ai_ai21_models:
+            optional_params = litellm.VertexAIAi21Config().map_openai_params(
+                non_default_params=non_default_params,
+                optional_params=optional_params,
+                model=model,
+                drop_params=(
+                    drop_params
+                    if drop_params is not None and isinstance(drop_params, bool)
+                    else False
+                ),
+            )
+        else:  # use generic openai-like param mapping
+            optional_params = litellm.VertexAILlama3Config().map_openai_params(
+                non_default_params=non_default_params,
+                optional_params=optional_params,
+                model=model,
+                drop_params=(
+                    drop_params
+                    if drop_params is not None and isinstance(drop_params, bool)
+                    else False
+                ),
+            )
 
     elif custom_llm_provider == "sagemaker":
         # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
         optional_params = litellm.SagemakerConfig().map_openai_params(
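The net effect of the restructuring above: a single elif now owns the whole vertex_ai provider, dispatching Mistral/Codestral and AI21 models to their dedicated configs and routing every other model, including fine-tuned endpoints addressed by numeric ids, through the generic OpenAI-like mapping in VertexAILlama3Config, which drops max_retries. A simplified sketch of that control flow follows; the function name and string results are illustrative, not the real utils.py code.

# Illustrative control flow only; the function and its return values are
# hypothetical, standing in for the config dispatch in get_optional_params.
def pick_vertex_config(model: str, mistral_models: set, ai21_models: set) -> str:
    if model in mistral_models:
        return "MistralConfig"  # real code further splits out codestral models
    elif model in ai21_models:
        return "VertexAIAi21Config"
    else:
        # Previously unmatched models got no mapping at all, so max_retries
        # leaked through; now they take the generic OpenAI-like path.
        return "VertexAILlama3Config"


# A fine-tuned endpoint id matches neither list and gets the generic mapping.
assert pick_vertex_config("1984786713414729728", set(), set()) == "VertexAILlama3Config"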
11 changes: 11 additions & 0 deletions tests/llm_translation/test_optional_params.py
@@ -1067,3 +1067,14 @@ def test_gemini_frequency_penalty():
         model="gemini-1.5-flash", custom_llm_provider="gemini", frequency_penalty=0.5
     )
     assert optional_params["frequency_penalty"] == 0.5
+
+
+def test_vertex_ai_ft_llama():
+    optional_params = get_optional_params(
+        model="1984786713414729728",
+        custom_llm_provider="vertex_ai",
+        frequency_penalty=0.5,
+        max_retries=10,
+    )
+    assert optional_params["frequency_penalty"] == 0.5
+    assert "max_retries" not in optional_params
