diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index c4af958a3cdd52..4886b62d72d922 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -170,13 +170,14 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
         features = []
 
         function_calling_type = credentials.get('function_calling_type', 'no_call')
-        if function_calling_type == 'function_call':
+        if function_calling_type in ['function_call']:
             features.append(ModelFeature.TOOL_CALL)
-            endpoint_url = credentials["endpoint_url"]
-            # if not endpoint_url.endswith('/'):
-            #     endpoint_url += '/'
-            # if 'https://api.openai.com/v1/' == endpoint_url:
-            #     features.append(ModelFeature.STREAM_TOOL_CALL)
+        elif function_calling_type in ['tool_call']:
+            features.append(ModelFeature.MULTI_TOOL_CALL)
+
+        stream_function_calling = credentials.get('stream_function_calling', 'supported')
+        if stream_function_calling == 'supported':
+            features.append(ModelFeature.STREAM_TOOL_CALL)
 
         vision_support = credentials.get('vision_support', 'not_support')
         if vision_support == 'support':
@@ -386,29 +387,37 @@ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, f
         def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
             def get_tool_call(tool_call_id: str):
-                tool_call = next(
-                    (tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None
-                )
+                if not tool_call_id:
+                    return tools_calls[-1]
+
+                tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
                 if tool_call is None:
                     tool_call = AssistantPromptMessage.ToolCall(
-                        id='',
-                        type='function',
+                        id=tool_call_id,
+                        type="function",
                         function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                            name='',
-                            arguments=''
+                            name="",
+                            arguments=""
                         )
                     )
                     tools_calls.append(tool_call)
+
                 return tool_call
 
             for new_tool_call in new_tool_calls:
                 # get tool call
-                tool_call = get_tool_call(new_tool_call.id)
+                tool_call = get_tool_call(new_tool_call.function.name)
                 # update tool call
-                tool_call.id = new_tool_call.id
-                tool_call.type = new_tool_call.type
-                tool_call.function.name = new_tool_call.function.name
-                tool_call.function.arguments += new_tool_call.function.arguments
+                if new_tool_call.id:
+                    tool_call.id = new_tool_call.id
+                if new_tool_call.type:
+                    tool_call.type = new_tool_call.type
+                if new_tool_call.function.name:
+                    tool_call.function.name = new_tool_call.function.name
+                if new_tool_call.function.arguments:
+                    tool_call.function.arguments += new_tool_call.function.arguments
+
+        finish_reason = 'Unknown'
 
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             if chunk:
@@ -438,7 +447,17 @@ def get_tool_call(tool_call_id: str):
                     delta = choice['delta']
                     delta_content = delta.get('content')
 
-                    assistant_message_tool_calls = delta.get('tool_calls', None)
+                    assistant_message_tool_calls = None
+
+                    if 'tool_calls' in delta and credentials.get('function_calling_type', 'no_call') == 'tool_call':
+                        assistant_message_tool_calls = delta.get('tool_calls', None)
+                    elif 'function_call' in delta and credentials.get('function_calling_type', 'no_call') == 'function_call':
+                        assistant_message_tool_calls = [{
+                            'id': 'tool_call_id',
+                            'type': 'function',
+                            'function': delta.get('function_call', {})
+                        }]
+
                     # assistant_message_function_call = delta.delta.function_call
 
                     # extract tool calls from response
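The two hunks above carry the heart of the fix: streamed tool-call deltas are merged by function name, with a fallback to the most recently started call when a continuation delta carries neither id nor name, and only non-empty fields overwrite what has already been collected. Keying on the name rather than the id is what lets the legacy `function_call` path, which has no real id, share the same accumulator. A minimal, runnable sketch of that behaviour, using plain dataclasses as stand-ins for Dify's `AssistantPromptMessage.ToolCall` pydantic model (the simulated deltas are hypothetical):

```python
from dataclasses import dataclass, field


@dataclass
class ToolCallFunction:
    name: str = ""
    arguments: str = ""


@dataclass
class ToolCall:
    id: str = ""
    type: str = "function"
    function: ToolCallFunction = field(default_factory=ToolCallFunction)


tools_calls: list[ToolCall] = []


def get_tool_call(tool_call_key: str) -> ToolCall:
    # Continuation deltas often carry no id and no name: append to the
    # most recently started call instead of creating a new one.
    if not tool_call_key:
        return tools_calls[-1]
    tool_call = next((t for t in tools_calls if t.id == tool_call_key), None)
    if tool_call is None:
        # As in the patch, the lookup key (the function name) seeds the id
        # until a real id arrives in a later delta.
        tool_call = ToolCall(id=tool_call_key)
        tools_calls.append(tool_call)
    return tool_call


# Hypothetical stream: the first delta has id/name, the second only arguments.
deltas = [
    ToolCall(id="call_0", function=ToolCallFunction(name="get_weather", arguments='{"city"')),
    ToolCall(function=ToolCallFunction(arguments=': "Paris"}')),
]
for new_tool_call in deltas:
    tool_call = get_tool_call(new_tool_call.function.name)  # keyed on name, as in the patch
    if new_tool_call.id:
        tool_call.id = new_tool_call.id
    if new_tool_call.function.name:
        tool_call.function.name = new_tool_call.function.name
    if new_tool_call.function.arguments:
        tool_call.function.arguments += new_tool_call.function.arguments

assert tools_calls[0].function.arguments == '{"city": "Paris"}'
```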
@@ -449,15 +468,13 @@ def get_tool_call(tool_call_id: str):
                     if delta_content is None or delta_content == '':
                         continue
 
-                    # function_call = self._extract_response_function_call(assistant_message_function_call)
-                    # tool_calls = [function_call] if function_call else []
-
                     # transform assistant message to prompt message
                     assistant_prompt_message = AssistantPromptMessage(
                         content=delta_content,
-                        tool_calls=tool_calls if assistant_message_tool_calls else []
                     )
 
+                    # reset tool calls
+                    tool_calls = []
                     full_assistant_content += delta_content
                 elif 'text' in choice:
                     choice_text = choice.get('text', '')
@@ -470,37 +487,36 @@ def get_tool_call(tool_call_id: str):
                 else:
                     continue
 
-                # check payload indicator for completion
-                if finish_reason is not None:
-                    yield LLMResultChunk(
-                        model=model,
-                        prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(
-                            index=chunk_index,
-                            message=AssistantPromptMessage(
-                                tool_calls=tools_calls,
-                            ),
-                            finish_reason=finish_reason
-                        )
-                    )
-
-                    yield create_final_llm_result_chunk(
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
                         index=chunk_index,
                         message=assistant_prompt_message,
-                        finish_reason=finish_reason
-                    )
-                else:
-                    yield LLMResultChunk(
-                        model=model,
-                        prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(
-                            index=chunk_index,
-                            message=assistant_prompt_message,
-                        )
                     )
+                )
 
                 chunk_index += 1
 
+        if tools_calls:
+            yield LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=chunk_index,
+                    message=AssistantPromptMessage(
+                        tool_calls=tools_calls,
+                        content=""
+                    ),
+                )
+            )
+
+        yield create_final_llm_result_chunk(
+            index=chunk_index,
+            message=AssistantPromptMessage(content=""),
+            finish_reason=finish_reason
+        )
+
     def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
@@ -757,13 +773,13 @@ def _extract_response_tool_calls(self,
         if response_tool_calls:
             for response_tool_call in response_tool_calls:
                 function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                    name=response_tool_call["function"]["name"],
-                    arguments=response_tool_call["function"]["arguments"]
+                    name=response_tool_call.get("function", {}).get("name", ""),
+                    arguments=response_tool_call.get("function", {}).get("arguments", "")
                 )
 
                 tool_call = AssistantPromptMessage.ToolCall(
-                    id=response_tool_call["id"],
-                    type=response_tool_call["type"],
+                    id=response_tool_call.get("id", ""),
+                    type=response_tool_call.get("type", ""),
                     function=function
                 )
                 tool_calls.append(tool_call)
@@ -781,12 +797,12 @@ def _extract_response_function_call(self, response_function_call) \
         tool_call = None
         if response_function_call:
             function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                name=response_function_call['name'],
-                arguments=response_function_call['arguments']
+                name=response_function_call.get('name', ''),
+                arguments=response_function_call.get('arguments', '')
             )
 
             tool_call = AssistantPromptMessage.ToolCall(
-                id=response_function_call['name'],
+                id=response_function_call.get('id', ''),
                 type="function",
                 function=function
            )
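With the per-chunk finish_reason branching removed, the stream settles into a fixed emission order: plain content chunks as they arrive, then at most one chunk carrying the fully accumulated tool calls, then a final chunk with the finish reason (and usage). A rough sketch of that shape, with plain dicts standing in for `LLMResultChunk` and `LLMResultChunkDelta` (names and payloads simplified, not Dify's real types):

```python
from collections.abc import Iterator


def emit_chunks(deltas: list[dict], tools_calls: list[dict]) -> Iterator[dict]:
    chunk_index = 0
    finish_reason = "Unknown"  # default, as in the patch
    for delta in deltas:
        if delta.get("finish_reason"):
            finish_reason = delta["finish_reason"]
        if delta.get("content"):
            # plain content chunk; tool calls are no longer attached here
            yield {"index": chunk_index, "content": delta["content"]}
            chunk_index += 1
    if tools_calls:
        # one chunk carrying all accumulated tool calls
        yield {"index": chunk_index, "tool_calls": tools_calls, "content": ""}
    # final chunk: empty message plus finish_reason (and usage, in the real code)
    yield {"index": chunk_index, "content": "", "finish_reason": finish_reason}


calls = [{"id": "call_0", "function": {"name": "get_weather", "arguments": "{}"}}]
for chunk in emit_chunks([{"content": "Checking"}, {"finish_reason": "tool_calls"}], calls):
    print(chunk)
```

Deferring the tool-call chunk until after the loop is what makes backends that stream arguments across many deltas work: nothing is emitted until the JSON fragments have been stitched back together.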
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
index cd53e149422fec..69bed9603902a6 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
@@ -86,14 +86,32 @@ model_credential_schema:
       default: no_call
       options:
         - value: function_call
+          label:
+            en_US: Function Call
+            zh_Hans: Function Call
+        - value: tool_call
+          label:
+            en_US: Tool Call
+            zh_Hans: Tool Call
+        - value: no_call
+          label:
+            en_US: Not Support
+            zh_Hans: 不支持
+    - variable: stream_function_calling
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Stream function calling
+      type: select
+      required: false
+      default: not_supported
+      options:
+        - value: supported
           label:
             en_US: Support
             zh_Hans: 支持
-#        - value: tool_call
-#          label:
-#            en_US: Tool Call
-#            zh_Hans: Tool Call
-        - value: no_call
+        - value: not_supported
           label:
             en_US: Not Support
             zh_Hans: 不支持
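Taken together, the new credential fields map onto model features roughly as sketched below (feature names shown as plain strings instead of the `ModelFeature` enum). One thing worth noting: the code-side fallback in `get_customizable_model_schema()` is `'supported'`, while the YAML declares `not_supported` as the form default.

```python
# Hedged sketch of the credentials-to-features mapping implemented above;
# feature names are abbreviated strings, not Dify's ModelFeature enum.
def features_from_credentials(credentials: dict) -> list[str]:
    features = []
    function_calling_type = credentials.get("function_calling_type", "no_call")
    if function_calling_type == "function_call":
        features.append("TOOL_CALL")          # legacy OpenAI function_call API
    elif function_calling_type == "tool_call":
        features.append("MULTI_TOOL_CALL")    # newer tools/tool_calls API
    if credentials.get("stream_function_calling", "supported") == "supported":
        features.append("STREAM_TOOL_CALL")   # tool calls usable while streaming
    return features


# e.g. a backend exposing OpenAI-style tools plus streamed tool calls:
print(features_from_credentials({
    "function_calling_type": "tool_call",
    "stream_function_calling": "supported",
}))  # ['MULTI_TOOL_CALL', 'STREAM_TOOL_CALL']
```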