From 2f6c9f45ccc5e77cde5a740f180e65824b53e6e6 Mon Sep 17 00:00:00 2001
From: Vibhav Bhat
Date: Thu, 30 Jan 2025 20:31:32 -0800
Subject: [PATCH 1/5] Fix Bedrock Anthropic topK bug

---
 .../bedrock/chat/converse_transformation.py  | 49 ++++++++++---------
 .../test_bedrock_completion.py               | 17 +++++++
 2 files changed, 44 insertions(+), 22 deletions(-)

diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py
index 521dd20854bb..6987e9649fa0 100644
--- a/litellm/llms/bedrock/chat/converse_transformation.py
+++ b/litellm/llms/bedrock/chat/converse_transformation.py
@@ -31,7 +31,7 @@
     OpenAIMessageContentListBlock,
 )
 from litellm.types.utils import ModelResponse, Usage
-from litellm.utils import add_dummy_tool, has_tool_call_blocks
+from litellm.utils import CustomStreamWrapper, add_dummy_tool, has_tool_call_blocks
 
 from ..common_utils import (
     AmazonBedrockGlobalConfig,
@@ -332,9 +332,27 @@ def _transform_inference_params(self, inference_params: dict) -> InferenceConfig
         if "top_k" in inference_params:
             inference_params["topK"] = inference_params.pop("top_k")
         return InferenceConfig(**inference_params)
+
+    def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
+        base_model = self._get_base_model(model)
+
+        valTopK = None
+        if "topK" in inference_params:
+            valTopK = inference_params.pop("topK")
+        elif "top_k" in inference_params:
+            valTopK = inference_params.pop("top_k")
+
+        if valTopK:
+            if (base_model.startswith("anthropic")):
+                return {"top_k": valTopK}
+            elif base_model.startswith("amazon.nova"):
+                return {'inferenceConfig': {"topK": valTopK}}
+
+        return {}
 
     def _transform_request_helper(
         self,
+        model: str,
         system_content_blocks: List[SystemContentBlock],
         optional_params: dict,
         messages: Optional[List[AllMessageValues]] = None,
@@ -361,35 +379,20 @@
         )
 
         inference_params = copy.deepcopy(optional_params)
-        additional_request_keys = []
-        additional_request_params = {}
         supported_converse_params = list(
             AmazonConverseConfig.__annotations__.keys()
         ) + ["top_k"]
         supported_tool_call_params = ["tools", "tool_choice"]
         supported_guardrail_params = ["guardrailConfig"]
+        total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params
         inference_params.pop("json_mode", None)  # used for handling json_schema
 
-        # send all model-specific params in 'additional_request_params'
-        for k, v in inference_params.items():
-            if (
-                k not in supported_converse_params
-                and k not in supported_tool_call_params
-                and k not in supported_guardrail_params
-            ):
-                additional_request_params[k] = v
-                additional_request_keys.append(k)
-        for key in additional_request_keys:
-            inference_params.pop(key, None)
+        # keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
+        additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params}
+        inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params}
 
-        if "topK" in inference_params:
-            additional_request_params["inferenceConfig"] = {
-                "topK": inference_params.pop("topK")
-            }
-        elif "top_k" in inference_params:
-            additional_request_params["inferenceConfig"] = {
-                "topK": inference_params.pop("top_k")
-            }
+        # Only set the topK value for models that support it
+        additional_request_params.update(self._handle_top_k_value(model, inference_params))
 
         bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
             inference_params.pop("tools", [])
@@ -437,6 +440,7 @@ async def _async_transform_request(
 
         ## TRANSFORMATION ##
         _data: CommonRequestObject = self._transform_request_helper(
+            model=model,
             system_content_blocks=system_content_blocks,
             optional_params=optional_params,
             messages=messages,
@@ -483,6 +487,7 @@ def _transform_request(
         messages, system_content_blocks = self._transform_system_message(messages)
 
         _data: CommonRequestObject = self._transform_request_helper(
+            model=model,
             system_content_blocks=system_content_blocks,
             optional_params=optional_params,
             messages=messages,
diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py
index 5f9c01f7bb1d..731d5717ba11 100644
--- a/tests/llm_translation/test_bedrock_completion.py
+++ b/tests/llm_translation/test_bedrock_completion.py
@@ -2580,3 +2580,20 @@ def test_bedrock_custom_deepseek():
     except Exception as e:
         print(f"Error: {str(e)}")
         raise e
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
+        "bedrock/converse/us.amazon.nova-pro-v1:0",
+        "bedrock/meta.llama3-70b-instruct-v1:0",
+        "bedrock/mistral.mistral-7b-instruct-v0:2",
+    ]
+)
+def test_bedrock_top_k(model):
+    litellm.completion(
+        model=model,
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        top_k=2,
+    )
+

From b00809310d4b37dd8394a965048470eb13800aa1 Mon Sep 17 00:00:00 2001
From: Vibhav Bhat
Date: Thu, 30 Jan 2025 20:50:33 -0800
Subject: [PATCH 2/5] Remove extra import

---
 litellm/llms/bedrock/chat/converse_transformation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py
index 6987e9649fa0..fa6652085d23 100644
--- a/litellm/llms/bedrock/chat/converse_transformation.py
+++ b/litellm/llms/bedrock/chat/converse_transformation.py
@@ -31,7 +31,7 @@
     OpenAIMessageContentListBlock,
 )
 from litellm.types.utils import ModelResponse, Usage
-from litellm.utils import CustomStreamWrapper, add_dummy_tool, has_tool_call_blocks
+from litellm.utils import add_dummy_tool, has_tool_call_blocks
 
 from ..common_utils import (
     AmazonBedrockGlobalConfig,

From 2d5ae20ced7239e53788911cfde4295566ea1fde Mon Sep 17 00:00:00 2001
From: Vibhav Bhat
Date: Fri, 31 Jan 2025 00:21:10 -0800
Subject: [PATCH 3/5] Add unit test + make tests mocked

---
 .../test_bedrock_completion.py | 49 ++++++++++++++-----
 1 file changed, 38 insertions(+), 11 deletions(-)

diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py
index 731d5717ba11..e1a40390fe55 100644
--- a/tests/llm_translation/test_bedrock_completion.py
+++ b/tests/llm_translation/test_bedrock_completion.py
@@ -2582,18 +2582,45 @@ def test_bedrock_custom_deepseek():
         raise e
 
 @pytest.mark.parametrize(
-    "model",
+    "model, expected_output",
     [
-        "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
-        "bedrock/converse/us.amazon.nova-pro-v1:0",
-        "bedrock/meta.llama3-70b-instruct-v1:0",
-        "bedrock/mistral.mistral-7b-instruct-v0:2",
+        ("bedrock/anthropic.claude-3-sonnet-20240229-v1:0", {"top_k": 3}),
+        ("bedrock/converse/us.amazon.nova-pro-v1:0", {'inferenceConfig': {"topK": 3}}),
+        ("bedrock/meta.llama3-70b-instruct-v1:0", {}),
     ]
 )
-def test_bedrock_top_k(model):
-    litellm.completion(
-        model=model,
-        messages=[{"role": "user", "content": "Hello, world!"}],
-        top_k=2,
-    )
+def test_handle_top_k_value_helper(model, expected_output):
+    assert litellm.AmazonConverseConfig()._handle_top_k_value(model, {"topK": 3}) == expected_output
+    assert litellm.AmazonConverseConfig()._handle_top_k_value(model, {"top_k": 3}) == expected_output
+
+@pytest.mark.parametrize(
+    "model, expected_params",
+    [
+        ("bedrock/anthropic.claude-3-sonnet-20240229-v1:0", {"top_k": 2}),
+        ("bedrock/converse/us.amazon.nova-pro-v1:0", {'inferenceConfig': {"topK": 2}}),
+        ("bedrock/meta.llama3-70b-instruct-v1:0", {}),
+        ("bedrock/mistral.mistral-7b-instruct-v0:2", {}),
+
+    ]
+)
+def test_bedrock_top_k_param(model, expected_params):
+    import json
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_post:
+        try:
+            litellm.completion(
+                model=model,
+                messages=[{"role": "user", "content": "Hello, world!"}],
+                top_k=2,
+                client=client
+            )
+        except Exception as e:
+            print(e)
+
+        data = json.loads(mock_post.call_args.kwargs["data"])
+        if ("mistral" in model):
+            assert (data["top_k"] == 2)
+        else:
+            assert (data["additionalModelRequestFields"] == expected_params)
 

From d726418b7720f02eef201927bdcd205c92f21f25 Mon Sep 17 00:00:00 2001
From: Vibhav Bhat
Date: Fri, 31 Jan 2025 11:53:57 -0800
Subject: [PATCH 4/5] Fix camel case

---
 .../llms/bedrock/chat/converse_transformation.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py
index fa6652085d23..cab4a413d1d4 100644
--- a/litellm/llms/bedrock/chat/converse_transformation.py
+++ b/litellm/llms/bedrock/chat/converse_transformation.py
@@ -336,17 +336,17 @@ def _transform_inference_params(self, inference_params: dict) -> InferenceConfig
     def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
         base_model = self._get_base_model(model)
 
-        valTopK = None
+        val_top_k = None
         if "topK" in inference_params:
-            valTopK = inference_params.pop("topK")
+            val_top_k = inference_params.pop("topK")
         elif "top_k" in inference_params:
-            valTopK = inference_params.pop("top_k")
+            val_top_k = inference_params.pop("top_k")
 
-        if valTopK:
+        if val_top_k:
             if (base_model.startswith("anthropic")):
-                return {"top_k": valTopK}
-            elif base_model.startswith("amazon.nova"):
-                return {'inferenceConfig': {"topK": valTopK}}
+                return {"top_k": val_top_k}
+            if base_model.startswith("amazon.nova"):
+                return {'inferenceConfig': {"topK": val_top_k}}
 
         return {}

From 5015e877ff4ff757e47f77431511ba79884470fb Mon Sep 17 00:00:00 2001
From: Vibhav Bhat
Date: Mon, 3 Feb 2025 20:35:11 -0800
Subject: [PATCH 5/5] Fix tests to remove exception handling

---
 .../test_bedrock_completion.py | 43 +++++++++++++++----
 1 file changed, 34 insertions(+), 9 deletions(-)

diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py
index e1a40390fe55..f9f6bdef5810 100644
--- a/tests/llm_translation/test_bedrock_completion.py
+++ b/tests/llm_translation/test_bedrock_completion.py
@@ -2609,16 +2609,41 @@ def test_bedrock_top_k_param(model, expected_params):
     client = HTTPHandler()
 
     with patch.object(client, "post") as mock_post:
-        try:
-            litellm.completion(
-                model=model,
-                messages=[{"role": "user", "content": "Hello, world!"}],
-                top_k=2,
-                client=client
-            )
-        except Exception as e:
-            print(e)
+        mock_response = Mock()
+
+        if ("mistral" in model):
+            mock_response.text = json.dumps({"outputs": [{"text": "Here's a joke...", "stop_reason": "stop"}]})
+        else:
+            mock_response.text = json.dumps(
+                {
+                    "output": {
+                        "message": {
+                            "role": "assistant",
+                            "content": [
+                                {
"text": "Here's a joke..." + } + ] + } + }, + "usage": {"inputTokens": 12, "outputTokens": 6, "totalTokens": 18}, + "stopReason": "stop" + } + ) + + mock_response.status_code = 200 + # Add required response attributes + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json = lambda: json.loads(mock_response.text) + mock_post.return_value = mock_response + + litellm.completion( + model=model, + messages=[{"role": "user", "content": "Hello, world!"}], + top_k=2, + client=client + ) data = json.loads(mock_post.call_args.kwargs["data"]) if ("mistral" in model): assert (data["top_k"] == 2)