Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not record tool call id if no tool has been called #571

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
.PHONY: quality style test docs utils

# Python interpreter used for all tooling; override with `make PYTHON=python3.11 ...`
PYTHON ?= python

check_dirs := examples src tests utils

# Check code quality of the source code
# Tools are invoked via `$(PYTHON) -m` so they run from the same interpreter
# / virtualenv as the project, not whatever happens to be first on PATH.
quality:
	$(PYTHON) -m ruff check $(check_dirs)
	$(PYTHON) -m ruff format --check $(check_dirs)
	$(PYTHON) utils/check_tests_in_ci.py

# Format source code automatically
style:
	$(PYTHON) -m ruff check $(check_dirs) --fix
	$(PYTHON) -m ruff format $(check_dirs)

# Run smolagents tests
test:
	$(PYTHON) -m pytest ./tests/
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ test = [
dev = [
"smolagents[quality,test]",
"sqlalchemy", # for ./examples
"mlx-lm",
]

[tool.pytest.ini_options]
Expand Down Expand Up @@ -111,4 +112,4 @@ lines-after-imports = 2

[project.scripts]
smolagent = "smolagents.cli:main"
webagent = "smolagents.vision_web_browser:main"
webagent = "smolagents.vision_web_browser:main"
16 changes: 6 additions & 10 deletions src/smolagents/memory.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from dataclasses import asdict, dataclass
from logging import getLogger
from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, Union
from typing import Any, List, TypedDict, Union

from smolagents.models import ChatMessage, MessageRole
from smolagents.monitoring import AgentLogger, LogLevel
from smolagents.utils import AgentError, make_json_serializable


if TYPE_CHECKING:
from smolagents.models import ChatMessage
from smolagents.monitoring import AgentLogger


logger = getLogger(__name__)


Expand Down Expand Up @@ -42,7 +37,7 @@ class MemoryStep:
def dict(self):
return asdict(self)

def to_messages(self, **kwargs) -> List[Dict[str, Any]]:
def to_messages(self, **kwargs) -> List[Message]:
raise NotImplementedError


Expand All @@ -55,7 +50,7 @@ class ActionStep(MemoryStep):
step_number: int | None = None
error: AgentError | None = None
duration: float | None = None
model_output_message: ChatMessage = None
model_output_message: ChatMessage | None = None
model_output: str | None = None
observations: str | None = None
observations_images: List[str] | None = None
Expand All @@ -78,7 +73,7 @@ def dict(self):
}

def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Message]:
messages = []
messages: List[Message] = []
if self.model_input_messages is not None and show_model_input_messages:
messages.append(Message(role=MessageRole.SYSTEM, content=self.model_input_messages))
if self.model_output is not None and not summary_mode:
Expand Down Expand Up @@ -106,7 +101,8 @@ def to_messages(self, summary_mode: bool = False, show_model_input_messages: boo
content=[
{
"type": "text",
"text": f"Call id: {self.tool_calls[0].id}\nObservation:\n{self.observations}",
"text": (f"Call id: {self.tool_calls[0].id}\n" if self.tool_calls else "")
+ f"Observation:\n{self.observations}",
}
],
)
Expand Down
20 changes: 10 additions & 10 deletions src/smolagents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def from_hf_api(cls, message, raw) -> "ChatMessage":
return cls(role=message.role, content=message.content, tool_calls=tool_calls, raw=raw)

@classmethod
def from_dict(cls, data: dict) -> "ChatMessage":
def from_dict(cls, data: dict[str, Any]) -> "ChatMessage":
if data.get("tool_calls"):
tool_calls = [
ChatMessageToolCall(
Expand All @@ -123,11 +123,11 @@ def dict(self):
def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
    """Return ``arguments`` as a dict when possible.

    Dicts pass through untouched. Strings are parsed as JSON; if parsing
    fails, the original string is returned unchanged so the caller can
    decide how to handle a non-JSON payload.
    """
    if isinstance(arguments, dict):
        return arguments
    # Best-effort decode: malformed JSON is not an error here — broad
    # `except` is deliberate so any parse failure falls back to the raw string.
    try:
        return json.loads(arguments)
    except Exception:
        return arguments


def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage:
Expand Down Expand Up @@ -155,7 +155,7 @@ def roles(cls):
}


def get_tool_json_schema(tool: Tool) -> Dict:
def get_tool_json_schema(tool: Tool) -> dict[str, Any]:
properties = deepcopy(tool.inputs)
required = []
for key, value in properties.items():
Expand Down Expand Up @@ -298,7 +298,7 @@ def _prepare_completion_kwargs(

return completion_kwargs

def get_token_counts(self) -> Dict[str, int]:
def get_token_counts(self) -> dict[str, int | None]:
return {
"input_token_count": self.last_input_token_count,
"output_token_count": self.last_output_token_count,
Expand Down Expand Up @@ -331,11 +331,11 @@ def __call__(
"""
pass # To be implemented in child classes!

def to_dict(self) -> Dict:
def to_dict(self) -> dict[str, Any]:
"""
Converts the model into a JSON-compatible dictionary.
"""
model_dictionary = {
model_dictionary: dict[str, Any] = {
**self.kwargs,
"last_input_token_count": self.last_input_token_count,
"last_output_token_count": self.last_output_token_count,
Expand Down
52 changes: 44 additions & 8 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
populate_template,
)
from smolagents.default_tools import DuckDuckGoSearchTool, FinalAnswerTool, PythonInterpreterTool, VisitWebpageTool
from smolagents.memory import PlanningStep
from smolagents.memory import ActionStep, PlanningStep
from smolagents.models import (
ChatMessage,
ChatMessageToolCall,
Expand Down Expand Up @@ -146,7 +146,7 @@ def __call__(self, messages, tools_to_call_from=None, stop_sequences=None, gramm
)


def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
def fake_code_model(messages, stop_sequences=None, grammar=None) -> ChatMessage:
prompt = str(messages)
if "special_marker" not in prompt:
return ChatMessage(
Expand All @@ -172,7 +172,7 @@ def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
)


def fake_code_model_error(messages, stop_sequences=None) -> str:
def fake_code_model_error(messages, stop_sequences=None) -> ChatMessage:
prompt = str(messages)
if "special_marker" not in prompt:
return ChatMessage(
Expand Down Expand Up @@ -202,7 +202,7 @@ def error_function():
)


def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
def fake_code_model_syntax_error(messages, stop_sequences=None) -> ChatMessage:
prompt = str(messages)
if "special_marker" not in prompt:
return ChatMessage(
Expand Down Expand Up @@ -231,7 +231,7 @@ def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
)


def fake_code_model_import(messages, stop_sequences=None) -> str:
def fake_code_model_import(messages, stop_sequences=None) -> ChatMessage:
return ChatMessage(
role="assistant",
content="""
Expand All @@ -245,7 +245,7 @@ def fake_code_model_import(messages, stop_sequences=None) -> str:
)


def fake_code_functiondef(messages, stop_sequences=None) -> str:
def fake_code_functiondef(messages, stop_sequences=None) -> ChatMessage:
prompt = str(messages)
if "special_marker" not in prompt:
return ChatMessage(
Expand Down Expand Up @@ -276,7 +276,7 @@ def moving_average(x, w):
)


def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> ChatMessage:
return ChatMessage(
role="assistant",
content="""
Expand All @@ -290,7 +290,7 @@ def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) ->
)


def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> ChatMessage:
return ChatMessage(
role="assistant",
content="""
Expand All @@ -304,6 +304,18 @@ def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> st
)


def fake_no_valid_code_block(messages, stop_sequences=None, grammar=None) -> ChatMessage:
    """Fake model whose reply contains no parseable code block.

    Used to exercise the agent's handling of model output that cannot be
    executed as code.
    """
    malformed_reply = """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
I am Incorrect Python Code !rint(result)
```
"""
    return ChatMessage(role="assistant", content=malformed_reply)


class AgentTests(unittest.TestCase):
def test_fake_toolcalling_agent(self):
agent = ToolCallingAgent(tools=[PythonInterpreterTool()], model=FakeToolCallModel())
Expand Down Expand Up @@ -735,6 +747,26 @@ def test_planning_step(self, step, expected_messages_list):
for content, expected_content in zip(message["content"], expected_message["content"]):
assert content == expected_content

def test_agent_memory_to_messages_suceeds_when_tool_fails_but_obeservation_is_set(self):
    # NOTE(review): "suceeds"/"obeservation" in the name are typos; kept
    # as-is to avoid renaming the test in this change.
    interpreter = PythonInterpreterTool()

    def _record_observation(memory_step: ActionStep, agent: CodeAgent) -> None:
        # Simulate a step callback that fills in observations even though
        # the model produced no valid tool call.
        memory_step.observations = "observed something"

    agent = CodeAgent(
        tools=[interpreter],
        model=fake_no_valid_code_block,
        step_callbacks=[_record_observation],
    )

    # The fake model returns an invalid response, so no tool call gets
    # recorded; the observation comes solely from the callback above.
    agent.run("some task")

    # Converting memory to messages must not raise even though steps carry
    # observations without any associated tool call.
    for step in agent.memory.steps:
        step.to_messages()

@pytest.mark.parametrize(
"images, expected_messages_list",
[
Expand Down Expand Up @@ -904,6 +936,10 @@ def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:


class MultiAgentsTests(unittest.TestCase):
@pytest.fixture(autouse=True)
def initdir(self, tmp_path, monkeypatch):
    # Run every test in this class from a fresh temporary directory —
    # presumably so files written during the tests (e.g. by the save test
    # below) don't pollute the working tree. TODO confirm against callers.
    monkeypatch.chdir(tmp_path)

def test_multiagents_save(self):
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=2096, temperature=0.5)

Expand Down