
Commit

Intermediate state, harmonizing models
aymeric-roucher committed Feb 26, 2025
1 parent eb7b933 commit 6c0f86d
Showing 1 changed file with 44 additions and 49 deletions.
93 changes: 44 additions & 49 deletions src/smolagents/models.py
@@ -29,7 +29,7 @@
 from PIL import Image
 
 from .tools import Tool
-from .utils import _is_package_available, encode_image_base64, make_image_url
+from .utils import _is_package_available, encode_image_base64, make_image_url, parse_json_blob
 
 
 if TYPE_CHECKING:
@@ -104,29 +104,6 @@ def from_hf_api(cls, message, raw) -> "ChatMessage":
         tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
         return cls(role=message.role, content=message.content, tool_calls=tool_calls, raw=raw)
 
-    @classmethod
-    def from_vllm_api(cls, output, tools_to_call_from) -> "ChatMessage":
-        if tools_to_call_from is None:
-            return cls(role="assistant", content=output)
-
-        if "Action:" in output:
-            output = output.split("Action:", 1)[1].strip()
-
-        parsed_output = json.loads(output)
-        tool_name = parsed_output.get("tool_name")
-        tool_arguments = parsed_output.get("tool_arguments")
-
-        return cls(
-            role="assistant",
-            content="",
-            tool_calls=[
-                ChatMessageToolCall(
-                    id="".join(random.choices("0123456789", k=5)),
-                    type="function",
-                    function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
-                )
-            ],
-        )
 
     @classmethod
     def from_dict(cls, data: dict) -> "ChatMessage":
@@ -264,9 +241,11 @@ def get_clean_message_list(
 
 
 class Model:
-    def __init__(self, **kwargs):
+    def __init__(self, tool_name_key: str = "name", tool_arguments_key: str = "arguments", **kwargs):
         self.last_input_token_count = None
         self.last_output_token_count = None
+        self.tool_name_key = tool_name_key
+        self.tool_arguments_key = tool_arguments_key
         self.kwargs = kwargs
 
     def _prepare_completion_kwargs(
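The two constructor arguments above make the JSON keys used for tool calls configurable on the base class, so every backend can read the tool name and arguments from the same attributes. A minimal sketch of how a subclass could rely on them (ToyModel, its hard-coded completion, and the import path smolagents.models are illustrative assumptions, not part of this commit):

import json

from smolagents.models import ChatMessage, Model


class ToyModel(Model):
    """Hypothetical backend whose LLM labels tool calls as {"tool": ..., "args": ...}."""

    def __call__(self, messages, tools_to_call_from=None, **kwargs):
        # Pretend the underlying LLM produced this completion.
        output = '{"tool": "get_weather", "args": {"city": "Paris"}}'
        if tools_to_call_from is None:
            return ChatMessage(role="assistant", content=output)
        parsed = json.loads(output)
        name = parsed.get(self.tool_name_key)            # "get_weather"
        arguments = parsed.get(self.tool_arguments_key)  # {"city": "Paris"}
        return ChatMessage(role="assistant", content=f"{name} called with {arguments}")


model = ToyModel(tool_name_key="tool", tool_arguments_key="args")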
@@ -567,8 +546,24 @@ def __call__(
         output = out[0].outputs[0].text
         self.last_input_token_count = len(out[0].prompt_token_ids)
         self.last_output_token_count = len(out[0].outputs[0].token_ids)
-
-        return ChatMessage.from_vllm_api(output)
+        if tools_to_call_from is None:
+            return ChatMessage(role="assistant", content=output)
+        else:
+            parsed_output = json.loads(output)
+            return ChatMessage(
+                role="assistant",
+                content="",
+                tool_calls=[
+                    ChatMessageToolCall(
+                        id="".join(random.choices("0123456789", k=5)),
+                        type="function",
+                        function=ChatMessageToolCallDefinition(
+                            name=parsed_output.get(self.tool_name_key, None),
+                            arguments=parsed_output.get(self.tool_arguments_key, None)
+                        ),
+                    )
+                ],
+            )
 
 
 class MLXModel(Model):
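The VLLMModel branch above and the MLXModel.__call__ branch further down now resolve tool calls the same way: parse the completion as JSON, then read the fields named by tool_name_key and tool_arguments_key. A minimal sketch of that lookup, with an illustrative completion and the Model defaults (the values are made up):

import json

# Illustrative completion; real model output and tool names will differ.
output = '{"name": "image_generator", "arguments": {"prompt": "a cat"}}'
parsed_output = json.loads(output)

tool_name_key, tool_arguments_key = "name", "arguments"  # Model defaults
print(parsed_output.get(tool_name_key))        # image_generator
print(parsed_output.get(tool_arguments_key))   # {'prompt': 'a cat'}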
@@ -630,26 +625,6 @@ def __init__(
         self.tool_name_key = tool_name_key
         self.tool_arguments_key = tool_arguments_key
 
-    def _to_message(self, text, tools_to_call_from):
-        if tools_to_call_from:
-            # solution for extracting tool JSON without assuming a specific model output format
-            maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
-            parsed_text = json.loads(maybe_json)
-            tool_name = parsed_text.get(self.tool_name_key, None)
-            tool_arguments = parsed_text.get(self.tool_arguments_key, None)
-            if tool_name:
-                return ChatMessage(
-                    role="assistant",
-                    content="",
-                    tool_calls=[
-                        ChatMessageToolCall(
-                            id=uuid.uuid4(),
-                            type="function",
-                            function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
-                        )
-                    ],
-                )
-        return ChatMessage(role="assistant", content=text)
 
     def __call__(
         self,
@@ -689,9 +664,29 @@ def __call__(
                 stop_sequence_start = text.rfind(stop_sequence)
                 if stop_sequence_start != -1:
                     text = text[:stop_sequence_start]
-                    return self._to_message(text, tools_to_call_from)
+                    found_stop_sequence = True
+                    break
+            if found_stop_sequence:
+                break
 
-        return self._to_message(text, tools_to_call_from)
+        if tools_to_call_from:
+            # Extracts tool JSON without assuming a specific model output format
+            parsed_text = parse_json_blob(text)
+            tool_name = parsed_text.get(self.tool_name_key, None)
+            tool_arguments = parsed_text.get(self.tool_arguments_key, None)
+            if tool_name:
+                return ChatMessage(
+                    role="assistant",
+                    content="",
+                    tool_calls=[
+                        ChatMessageToolCall(
+                            id=uuid.uuid4(),
+                            type="function",
+                            function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
+                        )
+                    ],
+                )
+        return ChatMessage(role="assistant", content=text)
 
 
 class TransformersModel(Model):
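In the hunk above, the MLX path no longer assumes the completion is pure JSON: parse_json_blob (imported from .utils at the top of the file) extracts the tool-call object even when it sits inside surrounding text. A rough stand-in for that step, assuming it behaves like the first-brace/last-brace slicing it replaces (the real helper may differ), with an illustrative completion:

import json


def extract_json_blob(text: str) -> dict:
    # Assumed behaviour: keep everything from the first "{" to the last "}".
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end == -1:
        raise ValueError("No JSON blob found in model output")
    return json.loads(text[start : end + 1])


completion = 'Thought: I will answer directly.\nAction: {"name": "final_answer", "arguments": {"answer": "42"}}'
parsed = extract_json_blob(completion)
print(parsed["name"], parsed["arguments"])  # final_answer {'answer': '42'}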

