feat: support configurate openai compatible stream tool call #3467

Merged: 1 commit, merged on Apr 14, 2024.
Changes from all commits

@@ -170,13 +170,14 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
        features = []

        function_calling_type = credentials.get('function_calling_type', 'no_call')
        if function_calling_type == 'function_call':
        if function_calling_type in ['function_call']:
            features.append(ModelFeature.TOOL_CALL)
            endpoint_url = credentials["endpoint_url"]
            # if not endpoint_url.endswith('/'):
            #     endpoint_url += '/'
            # if 'https://api.openai.com/v1/' == endpoint_url:
            #     features.append(ModelFeature.STREAM_TOOL_CALL)
        elif function_calling_type in ['tool_call']:
            features.append(ModelFeature.MULTI_TOOL_CALL)

        stream_function_calling = credentials.get('stream_function_calling', 'supported')
        if stream_function_calling == 'supported':
            features.append(ModelFeature.STREAM_TOOL_CALL)

        vision_support = credentials.get('vision_support', 'not_support')
        if vision_support == 'support':
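For reference, a credentials dict that exercises these branches might look like the sketch below; the endpoint URL is a placeholder, and the option values are the ones defined in the provider YAML shown later in this diff.

# Hypothetical credentials for an OpenAI-compatible endpoint; the keys mirror
# the fields read in get_customizable_model_schema above.
credentials = {
    "endpoint_url": "https://example.com/v1",    # placeholder
    "function_calling_type": "tool_call",        # 'function_call', 'tool_call' or 'no_call'
    "stream_function_calling": "supported",      # 'supported' or 'not_supported'
    "vision_support": "not_support",
}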
@@ -386,29 +387,37 @@ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, f

        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
            def get_tool_call(tool_call_id: str):
                tool_call = next(
                    (tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None
                )
                if not tool_call_id:
                    return tools_calls[-1]

                tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
                if tool_call is None:
                    tool_call = AssistantPromptMessage.ToolCall(
                        id='',
                        type='function',
                        id=tool_call_id,
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name='',
                            arguments=''
                            name="",
                            arguments=""
                        )
                    )
                    tools_calls.append(tool_call)

                return tool_call

            for new_tool_call in new_tool_calls:
                # get tool call
                tool_call = get_tool_call(new_tool_call.id)
                tool_call = get_tool_call(new_tool_call.function.name)
                # update tool call
                tool_call.id = new_tool_call.id
                tool_call.type = new_tool_call.type
                tool_call.function.name = new_tool_call.function.name
                tool_call.function.arguments += new_tool_call.function.arguments
                if new_tool_call.id:
                    tool_call.id = new_tool_call.id
                if new_tool_call.type:
                    tool_call.type = new_tool_call.type
                if new_tool_call.function.name:
                    tool_call.function.name = new_tool_call.function.name
                if new_tool_call.function.arguments:
                    tool_call.function.arguments += new_tool_call.function.arguments

        finish_reason = 'Unknown'

        for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
            if chunk:
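The increase_tool_call helper above stitches partial tool-call fragments from successive stream chunks back into complete calls. Below is a minimal, self-contained sketch of that accumulation idea; it uses plain dicts in the OpenAI delta shape instead of the project's AssistantPromptMessage.ToolCall type, so the names and structures here are illustrative only.

# Sketch: merge streamed tool-call fragments into complete calls, keyed by
# function name (fragments that omit the name go to the latest call).
def merge_tool_call_deltas(deltas):
    calls = []

    def get_call(name):
        if not name and calls:
            return calls[-1]
        for call in calls:
            if call["function"]["name"] == name:
                return call
        call = {"id": "", "type": "function", "function": {"name": name, "arguments": ""}}
        calls.append(call)
        return call

    for delta in deltas:
        call = get_call(delta.get("function", {}).get("name", ""))
        if delta.get("id"):
            call["id"] = delta["id"]
        if delta.get("function", {}).get("arguments"):
            call["function"]["arguments"] += delta["function"]["arguments"]
    return calls

# Two fragments that together form one complete call:
fragments = [
    {"id": "call_0", "type": "function", "function": {"name": "get_weather", "arguments": '{"city": '}},
    {"function": {"arguments": '"Berlin"}'}},
]
print(merge_tool_call_deltas(fragments))
# [{'id': 'call_0', 'type': 'function',
#   'function': {'name': 'get_weather', 'arguments': '{"city": "Berlin"}'}}]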
@@ -438,7 +447,17 @@ def get_tool_call(tool_call_id: str):
                    delta = choice['delta']
                    delta_content = delta.get('content')

                    assistant_message_tool_calls = delta.get('tool_calls', None)
                    assistant_message_tool_calls = None

                    if 'tool_calls' in delta and credentials.get('function_calling_type', 'no_call') == 'tool_call':
                        assistant_message_tool_calls = delta.get('tool_calls', None)
                    elif 'function_call' in delta and credentials.get('function_calling_type', 'no_call') == 'function_call':
                        assistant_message_tool_calls = [{
                            'id': 'tool_call_id',
                            'type': 'function',
                            'function': delta.get('function_call', {})
                        }]

                    # assistant_message_function_call = delta.delta.function_call

                    # extract tool calls from response
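The new branch above also folds the legacy function_call delta format into the same shape as a tool_calls entry, so the accumulation path only has to deal with one structure. A rough dict-based sketch of that normalization, using the same literal placeholder id as the diff (legacy deltas carry no id of their own):

# Sketch: normalize a streamed delta into the tool_calls shape, covering both
# the newer 'tool_calls' field and the legacy 'function_call' field.
def normalize_delta_tool_calls(delta, function_calling_type):
    if 'tool_calls' in delta and function_calling_type == 'tool_call':
        return delta.get('tool_calls') or []
    if 'function_call' in delta and function_calling_type == 'function_call':
        return [{
            'id': 'tool_call_id',    # placeholder, as in the change above
            'type': 'function',
            'function': delta.get('function_call', {}),
        }]
    return []

# Example: a legacy-style delta becomes a single tool_calls-style entry.
legacy_delta = {'function_call': {'name': 'get_weather', 'arguments': '{"city": "Berlin"}'}}
print(normalize_delta_tool_calls(legacy_delta, 'function_call'))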
@@ -449,15 +468,13 @@ def get_tool_call(tool_call_id: str):
                    if delta_content is None or delta_content == '':
                        continue

                    # function_call = self._extract_response_function_call(assistant_message_function_call)
                    # tool_calls = [function_call] if function_call else []

                    # transform assistant message to prompt message
                    assistant_prompt_message = AssistantPromptMessage(
                        content=delta_content,
                        tool_calls=tool_calls if assistant_message_tool_calls else []
                    )

                    # reset tool calls
                    tool_calls = []
                    full_assistant_content += delta_content
                elif 'text' in choice:
                    choice_text = choice.get('text', '')
@@ -470,37 +487,36 @@ def get_tool_call(tool_call_id: str):
                else:
                    continue

                # check payload indicator for completion
                if finish_reason is not None:
                    yield LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=chunk_index,
                            message=AssistantPromptMessage(
                                tool_calls=tools_calls,
                            ),
                            finish_reason=finish_reason
                        )
                    )

                    yield create_final_llm_result_chunk(
                    yield LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=chunk_index,
                            message=assistant_prompt_message,
                            finish_reason=finish_reason
                        )
                else:
                    yield LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=chunk_index,
                            message=assistant_prompt_message,
                        )
                    )
                )

            chunk_index += 1

            if tools_calls:
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=AssistantPromptMessage(
                            tool_calls=tools_calls,
                            content=""
                        ),
                    )
                )

            yield create_final_llm_result_chunk(
                index=chunk_index,
                message=AssistantPromptMessage(content=""),
                finish_reason=finish_reason
            )

    def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
                                  prompt_messages: list[PromptMessage]) -> LLMResult:

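The net effect of this hunk: tool calls are no longer attached to mid-stream chunks; they are flushed in a single chunk once the SSE loop ends, followed by the final chunk carrying the finish reason. A rough generator-shaped sketch of that emit order, with plain dicts standing in for LLMResultChunk:

# Sketch of the post-change emit order: content chunks while streaming, then
# one chunk with all accumulated tool calls, then a final finish_reason chunk.
def stream_chunks(content_pieces, accumulated_tool_calls, finish_reason="stop"):
    index = 0
    for piece in content_pieces:
        yield {"index": index, "content": piece}
        index += 1
    if accumulated_tool_calls:
        yield {"index": index, "content": "", "tool_calls": accumulated_tool_calls}
    yield {"index": index, "content": "", "finish_reason": finish_reason}

calls = [{"id": "call_0", "type": "function",
          "function": {"name": "get_weather", "arguments": "{}"}}]
for chunk in stream_chunks(["Hel", "lo"], calls):
    print(chunk)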
@@ -757,13 +773,13 @@ def _extract_response_tool_calls(self,
        if response_tool_calls:
            for response_tool_call in response_tool_calls:
                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name=response_tool_call["function"]["name"],
                    arguments=response_tool_call["function"]["arguments"]
                    name=response_tool_call.get("function", {}).get("name", ""),
                    arguments=response_tool_call.get("function", {}).get("arguments", "")
                )

                tool_call = AssistantPromptMessage.ToolCall(
                    id=response_tool_call["id"],
                    type=response_tool_call["type"],
                    id=response_tool_call.get("id", ""),
                    type=response_tool_call.get("type", ""),
                    function=function
                )
                tool_calls.append(tool_call)
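The move from direct indexing to .get() with defaults makes _extract_response_tool_calls tolerant of OpenAI-compatible providers that omit id or type on tool call entries. A tiny illustration with a hypothetical payload:

# A tool call entry from a provider that omits 'id' and 'type'; indexing with
# entry["id"] would raise KeyError, while .get() falls back to empty strings.
entry = {"function": {"name": "get_weather", "arguments": "{}"}}
print(entry.get("id", ""), entry.get("type", ""))
print(entry.get("function", {}).get("name", ""))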
@@ -781,12 +797,12 @@ def _extract_response_function_call(self, response_function_call) \
        tool_call = None
        if response_function_call:
            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                name=response_function_call['name'],
                arguments=response_function_call['arguments']
                name=response_function_call.get('name', ''),
                arguments=response_function_call.get('arguments', '')
            )

            tool_call = AssistantPromptMessage.ToolCall(
                id=response_function_call['name'],
                id=response_function_call.get('id', ''),
                type="function",
                function=function
            )
@@ -86,14 +86,32 @@ model_credential_schema:
      default: no_call
      options:
        - value: function_call
          label:
            en_US: Function Call
            zh_Hans: Function Call
        - value: tool_call
          label:
            en_US: Tool Call
            zh_Hans: Tool Call
        - value: no_call
          label:
            en_US: Not Support
            zh_Hans: 不支持
    - variable: stream_function_calling
      show_on:
        - variable: __model_type
          value: llm
      label:
        en_US: Stream function calling
      type: select
      required: false
      default: not_supported
      options:
        - value: supported
          label:
            en_US: Support
            zh_Hans: 支持
        # - value: tool_call
        #   label:
        #     en_US: Tool Call
        #     zh_Hans: Tool Call
        - value: no_call
        - value: not_supported
          label:
            en_US: Not Support
            zh_Hans: 不支持