diff --git a/engines/python/setup/djl_python/output_formatter.py b/engines/python/setup/djl_python/output_formatter.py index e3e233123..e47bbe812 100644 --- a/engines/python/setup/djl_python/output_formatter.py +++ b/engines/python/setup/djl_python/output_formatter.py @@ -323,9 +323,8 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput): "finish_reason": best_sequence.finish_reason, } elif chat_params and chat_params.tools and ( - parameters.get("tool_choice") == "auto" - or parameters.get("tool_choice") - is None) and parameters.get("tool_parser"): + chat_params.tool_choice == "auto" + or chat_params.tool_choice is None) and tool_parser: tool_call_info = tool_parser.extract_tool_calls(generated_text, request=chat_params) auto_tools_called = tool_call_info.tools_called @@ -408,9 +407,9 @@ def _jsonlines_chat_output_formatter(request_output: TextGenerationOutput): } }] delta = {"tool_calls": tool_calls} - elif parameters.get("tools") and (parameters.get("tool_choice") == "auto" - or parameters.get("tool_choice") is None - ) and parameters.get("tool_parser"): + elif chat_params and chat_params.tools and ( + chat_params.tool_choice == "auto" + or chat_params.tool_choice is None) and tool_parser: current_text = get_generated_text(best_sequence, request_output) previous_text = current_text[0:-len(next_token.text)] current_token_ids = [t.id for t in best_sequence.tokens]