Fix passing top_k parameter for Bedrock Anthropic models #8131

Merged
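A minimal usage sketch (not part of the PR; it assumes AWS credentials and a Bedrock region are configured) of the behaviour this change targets, based on the tests added below: for Anthropic Claude models on Bedrock, `top_k` should now be forwarded under `additionalModelRequestFields` as `top_k`, Amazon Nova models keep the nested `inferenceConfig.topK` form, and models without `top_k` support get no additional field at all.

```python
import litellm

# Sketch only. Per the tests in this PR, the outgoing Converse request for an
# Anthropic model should carry {"additionalModelRequestFields": {"top_k": 5}},
# a Nova model gets {"additionalModelRequestFields": {"inferenceConfig": {"topK": 5}}},
# and unsupported models get neither.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Hello, world!"}],
    top_k=5,
)
print(response.choices[0].message.content)
```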
47 changes: 26 additions & 21 deletions litellm/llms/bedrock/chat/converse_transformation.py
@@ -332,9 +332,27 @@ def _transform_inference_params(self, inference_params: dict) -> InferenceConfig
if "top_k" in inference_params:
inference_params["topK"] = inference_params.pop("top_k")
return InferenceConfig(**inference_params)

    def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
        base_model = self._get_base_model(model)

        val_top_k = None
        if "topK" in inference_params:
            val_top_k = inference_params.pop("topK")
        elif "top_k" in inference_params:
            val_top_k = inference_params.pop("top_k")

        if val_top_k:
            if base_model.startswith("anthropic"):
                return {"top_k": val_top_k}
            if base_model.startswith("amazon.nova"):
                return {"inferenceConfig": {"topK": val_top_k}}

        return {}

    def _transform_request_helper(
        self,
        model: str,
        system_content_blocks: List[SystemContentBlock],
        optional_params: dict,
        messages: Optional[List[AllMessageValues]] = None,
@@ -361,35 +379,20 @@ def _transform_request_helper(
    )

        inference_params = copy.deepcopy(optional_params)
        additional_request_keys = []
        additional_request_params = {}
        supported_converse_params = list(
            AmazonConverseConfig.__annotations__.keys()
        ) + ["top_k"]
        supported_tool_call_params = ["tools", "tool_choice"]
        supported_guardrail_params = ["guardrailConfig"]
        total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params
        inference_params.pop("json_mode", None)  # used for handling json_schema

        # send all model-specific params in 'additional_request_params'
        for k, v in inference_params.items():
            if (
                k not in supported_converse_params
                and k not in supported_tool_call_params
Contributor: is this deletion required? I'm pretty sure we need to filter our specific guardrail params

Contributor Author: I just wanted to simplify the code a bit (using list comprehension instead of adding and then popping from the dict). From my understanding the before and after logic should be functionally equivalent, but if I'm wrong, I can revert it.

                and k not in supported_guardrail_params
            ):
                additional_request_params[k] = v
                additional_request_keys.append(k)
        for key in additional_request_keys:
            inference_params.pop(key, None)
        # keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
        additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params}
        inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params}

if "topK" in inference_params:
additional_request_params["inferenceConfig"] = {
"topK": inference_params.pop("topK")
}
elif "top_k" in inference_params:
additional_request_params["inferenceConfig"] = {
"topK": inference_params.pop("top_k")
}
# Only set the topK value in for models that support it
additional_request_params.update(self._handle_top_k_value(model, inference_params))

        bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
            inference_params.pop("tools", [])
@@ -437,6 +440,7 @@ async def _async_transform_request(
        ## TRANSFORMATION ##

        _data: CommonRequestObject = self._transform_request_helper(
            model=model,
            system_content_blocks=system_content_blocks,
            optional_params=optional_params,
            messages=messages,
@@ -483,6 +487,7 @@ def _transform_request(
        messages, system_content_blocks = self._transform_system_message(messages)

        _data: CommonRequestObject = self._transform_request_helper(
            model=model,
            system_content_blocks=system_content_blocks,
            optional_params=optional_params,
            messages=messages,
69 changes: 69 additions & 0 deletions tests/llm_translation/test_bedrock_completion.py
@@ -2580,3 +2580,72 @@ def test_bedrock_custom_deepseek():
    except Exception as e:
        print(f"Error: {str(e)}")
        raise e

@pytest.mark.parametrize(
"model, expected_output",
[
("bedrock/anthropic.claude-3-sonnet-20240229-v1:0", {"top_k": 3}),
("bedrock/converse/us.amazon.nova-pro-v1:0", {'inferenceConfig': {"topK": 3}}),
("bedrock/meta.llama3-70b-instruct-v1:0", {}),
]
)
def test_handle_top_k_value_helper(model, expected_output):
assert litellm.AmazonConverseConfig()._handle_top_k_value(model, {"topK": 3}) == expected_output
assert litellm.AmazonConverseConfig()._handle_top_k_value(model, {"top_k": 3}) == expected_output

@pytest.mark.parametrize(
    "model, expected_params",
    [
        ("bedrock/anthropic.claude-3-sonnet-20240229-v1:0", {"top_k": 2}),
        ("bedrock/converse/us.amazon.nova-pro-v1:0", {"inferenceConfig": {"topK": 2}}),
        ("bedrock/meta.llama3-70b-instruct-v1:0", {}),
        ("bedrock/mistral.mistral-7b-instruct-v0:2", {}),
    ]
)
def test_bedrock_top_k_param(model, expected_params):
    import json

    client = HTTPHandler()

    with patch.object(client, "post") as mock_post:
        mock_response = Mock()

        if "mistral" in model:
            mock_response.text = json.dumps({"outputs": [{"text": "Here's a joke...", "stop_reason": "stop"}]})
        else:
            mock_response.text = json.dumps(
                {
                    "output": {
                        "message": {
                            "role": "assistant",
                            "content": [
                                {
                                    "text": "Here's a joke..."
                                }
                            ]
                        }
                    },
                    "usage": {"inputTokens": 12, "outputTokens": 6, "totalTokens": 18},
                    "stopReason": "stop"
                }
            )

        mock_response.status_code = 200
        # Add required response attributes
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json = lambda: json.loads(mock_response.text)
        mock_post.return_value = mock_response

        litellm.completion(
            model=model,
            messages=[{"role": "user", "content": "Hello, world!"}],
            top_k=2,
            client=client
        )
        data = json.loads(mock_post.call_args.kwargs["data"])
        if "mistral" in model:
            assert data["top_k"] == 2
        else:
            assert data["additionalModelRequestFields"] == expected_params