Fix passing top_k parameter for Bedrock Anthropic models (#8131) (#8269)
* Fix Bedrock Anthropic topK bug

* Remove extra import

* Add unit test + make tests mocked

* Fix camel case

* Fix tests to remove exception handling

Co-authored-by: vibhavbhat <vibhavb00@gmail.com>
ishaan-jaff and vibhavbhat authored Feb 5, 2025
1 parent d367f42 commit b965d9b
Showing 2 changed files with 95 additions and 21 deletions.
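The user-visible contract being fixed: a `top_k` passed to `litellm.completion` must reach Bedrock in the field each model family expects. A minimal sketch of such a call, assuming AWS credentials are configured (the model ID is one used in the tests below):

import litellm

# After this fix, top_k for an Anthropic model on Bedrock Converse is forwarded as
# additionalModelRequestFields={"top_k": 2}; previously it was sent nested as
# inferenceConfig={"topK": 2}, which is the Nova shape, not the Anthropic one.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Tell me a joke"}],
    top_k=2,
)
print(response.choices[0].message.content)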
47 changes: 26 additions & 21 deletions litellm/llms/bedrock/chat/converse_transformation.py
@@ -332,9 +332,27 @@ def _transform_inference_params(self, inference_params: dict) -> InferenceConfig
         if "top_k" in inference_params:
             inference_params["topK"] = inference_params.pop("top_k")
         return InferenceConfig(**inference_params)
 
+    def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
+        base_model = self._get_base_model(model)
+
+        val_top_k = None
+        if "topK" in inference_params:
+            val_top_k = inference_params.pop("topK")
+        elif "top_k" in inference_params:
+            val_top_k = inference_params.pop("top_k")
+
+        if val_top_k:
+            if base_model.startswith("anthropic"):
+                return {"top_k": val_top_k}
+            if base_model.startswith("amazon.nova"):
+                return {'inferenceConfig': {"topK": val_top_k}}
+
+        return {}
+
     def _transform_request_helper(
         self,
+        model: str,
         system_content_blocks: List[SystemContentBlock],
         optional_params: dict,
         messages: Optional[List[AllMessageValues]] = None,
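To make the new routing concrete, here is a self-contained sketch of the helper's decision table; the `get_base_model` stub is a hypothetical, simplified stand-in for `self._get_base_model`, which resolves a litellm model string to the bare Bedrock model ID:

def get_base_model(model: str) -> str:
    # Hypothetical simplification: drop the "bedrock/" (and optional route)
    # prefix, then any cross-region prefix such as "us." or "eu."
    base = model.split("/")[-1]
    for region in ("us.", "eu.", "ap."):
        if base.startswith(region):
            base = base[len(region):]
    return base

def handle_top_k_value(model: str, inference_params: dict) -> dict:
    # Mirrors _handle_top_k_value: route top-k into the shape each family expects.
    base_model = get_base_model(model)
    val_top_k = None
    if "topK" in inference_params:
        val_top_k = inference_params.pop("topK")
    elif "top_k" in inference_params:
        val_top_k = inference_params.pop("top_k")
    if val_top_k:
        if base_model.startswith("anthropic"):
            return {"top_k": val_top_k}  # Claude: top-level snake_case
        if base_model.startswith("amazon.nova"):
            return {"inferenceConfig": {"topK": val_top_k}}  # Nova: nested camelCase
    return {}

assert handle_top_k_value("bedrock/anthropic.claude-3-sonnet-20240229-v1:0", {"top_k": 3}) == {"top_k": 3}
assert handle_top_k_value("bedrock/converse/us.amazon.nova-pro-v1:0", {"topK": 3}) == {"inferenceConfig": {"topK": 3}}
assert handle_top_k_value("bedrock/meta.llama3-70b-instruct-v1:0", {"top_k": 3}) == {}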
@@ -361,35 +379,20 @@ def _transform_request_helper(
         )
 
         inference_params = copy.deepcopy(optional_params)
-        additional_request_keys = []
-        additional_request_params = {}
         supported_converse_params = list(
             AmazonConverseConfig.__annotations__.keys()
         ) + ["top_k"]
         supported_tool_call_params = ["tools", "tool_choice"]
         supported_guardrail_params = ["guardrailConfig"]
+        total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params
         inference_params.pop("json_mode", None)  # used for handling json_schema
 
-        # send all model-specific params in 'additional_request_params'
-        for k, v in inference_params.items():
-            if (
-                k not in supported_converse_params
-                and k not in supported_tool_call_params
-                and k not in supported_guardrail_params
-            ):
-                additional_request_params[k] = v
-                additional_request_keys.append(k)
-        for key in additional_request_keys:
-            inference_params.pop(key, None)
+        # keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
+        additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params}
+        inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params}
 
-        if "topK" in inference_params:
-            additional_request_params["inferenceConfig"] = {
-                "topK": inference_params.pop("topK")
-            }
-        elif "top_k" in inference_params:
-            additional_request_params["inferenceConfig"] = {
-                "topK": inference_params.pop("top_k")
-            }
+        # Only set the topK value for models that support it
+        additional_request_params.update(self._handle_top_k_value(model, inference_params))
 
         bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
             inference_params.pop("tools", [])
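The rewritten parameter split replaces the bookkeeping loop with two dict comprehensions. A toy run, assuming a shortened supported-key list and a hypothetical model-specific parameter:

# Shortened for illustration; the real list is AmazonConverseConfig.__annotations__
# plus "top_k", the tool-call params, and the guardrail params.
total_supported_params = ["maxTokens", "temperature", "topP", "top_k", "tools", "tool_choice", "guardrailConfig"]

params = {"temperature": 0.7, "top_k": 2, "custom_vendor_param": True}  # last key hypothetical

additional_request_params = {k: v for k, v in params.items() if k not in total_supported_params}
inference_params = {k: v for k, v in params.items() if k in total_supported_params}

print(additional_request_params)  # {'custom_vendor_param': True}
print(inference_params)           # {'temperature': 0.7, 'top_k': 2}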
@@ -437,6 +440,7 @@ async def _async_transform_request(
         ## TRANSFORMATION ##
 
         _data: CommonRequestObject = self._transform_request_helper(
+            model=model,
             system_content_blocks=system_content_blocks,
             optional_params=optional_params,
             messages=messages,
@@ -483,6 +487,7 @@ def _transform_request(
         messages, system_content_blocks = self._transform_system_message(messages)
 
         _data: CommonRequestObject = self._transform_request_helper(
+            model=model,
             system_content_blocks=system_content_blocks,
             optional_params=optional_params,
             messages=messages,
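Net effect on the outgoing request, per the assertions in the tests below (shapes only; a real Converse request also carries messages, inferenceConfig, and so on):

# Where top_k=2 lands in additionalModelRequestFields, keyed by base model
# (values taken from the test expectations below):
expected_additional_fields = {
    "anthropic.claude-3-sonnet-20240229-v1:0": {"top_k": 2},
    "us.amazon.nova-pro-v1:0": {"inferenceConfig": {"topK": 2}},
    "meta.llama3-70b-instruct-v1:0": {},  # no supported top-k field; nothing forwarded
}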
69 changes: 69 additions & 0 deletions tests/llm_translation/test_bedrock_completion.py
@@ -2580,3 +2580,72 @@ def test_bedrock_custom_deepseek():
     except Exception as e:
         print(f"Error: {str(e)}")
         raise e
+
+@pytest.mark.parametrize(
+    "model, expected_output",
+    [
+        ("bedrock/anthropic.claude-3-sonnet-20240229-v1:0", {"top_k": 3}),
+        ("bedrock/converse/us.amazon.nova-pro-v1:0", {'inferenceConfig': {"topK": 3}}),
+        ("bedrock/meta.llama3-70b-instruct-v1:0", {}),
+    ]
+)
+def test_handle_top_k_value_helper(model, expected_output):
+    assert litellm.AmazonConverseConfig()._handle_top_k_value(model, {"topK": 3}) == expected_output
+    assert litellm.AmazonConverseConfig()._handle_top_k_value(model, {"top_k": 3}) == expected_output
+
+@pytest.mark.parametrize(
+    "model, expected_params",
+    [
+        ("bedrock/anthropic.claude-3-sonnet-20240229-v1:0", {"top_k": 2}),
+        ("bedrock/converse/us.amazon.nova-pro-v1:0", {'inferenceConfig': {"topK": 2}}),
+        ("bedrock/meta.llama3-70b-instruct-v1:0", {}),
+        ("bedrock/mistral.mistral-7b-instruct-v0:2", {}),
+    ]
+)
+def test_bedrock_top_k_param(model, expected_params):
+    import json
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_post:
+        mock_response = Mock()
+
+        if "mistral" in model:
+            mock_response.text = json.dumps({"outputs": [{"text": "Here's a joke...", "stop_reason": "stop"}]})
+        else:
+            mock_response.text = json.dumps(
+                {
+                    "output": {
+                        "message": {
+                            "role": "assistant",
+                            "content": [
+                                {
+                                    "text": "Here's a joke..."
+                                }
+                            ]
+                        }
+                    },
+                    "usage": {"inputTokens": 12, "outputTokens": 6, "totalTokens": 18},
+                    "stopReason": "stop"
+                }
+            )
+
+        mock_response.status_code = 200
+        # Add required response attributes
+        mock_response.headers = {"Content-Type": "application/json"}
+        mock_response.json = lambda: json.loads(mock_response.text)
+        mock_post.return_value = mock_response
+
+
+        litellm.completion(
+            model=model,
+            messages=[{"role": "user", "content": "Hello, world!"}],
+            top_k=2,
+            client=client
+        )
+        data = json.loads(mock_post.call_args.kwargs["data"])
+        if "mistral" in model:
+            assert data["top_k"] == 2
+        else:
+            assert data["additionalModelRequestFields"] == expected_params
+