From c3c2d80fd5742ee3ccb99ea48a388b95e3d568be Mon Sep 17 00:00:00 2001 From: sysradium Date: Sun, 9 Feb 2025 19:53:18 +0100 Subject: [PATCH 1/9] Do not record tool call id if no tool has been called --- src/smolagents/memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/smolagents/memory.py b/src/smolagents/memory.py index 5875db596..b59de95d4 100644 --- a/src/smolagents/memory.py +++ b/src/smolagents/memory.py @@ -106,7 +106,7 @@ def to_messages(self, summary_mode: bool = False, show_model_input_messages: boo content=[ { "type": "text", - "text": f"Call id: {self.tool_calls[0].id}\nObservation:\n{self.observations}", + "text": f"f{'Call id: {self.tool_calls[0].id}\n' if self.tool_calls else ''}Observation:\n{self.observations}", } ], ) From a1829205f5d52196cf42053abca73dba43818218 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 13 Feb 2025 22:14:29 +0100 Subject: [PATCH 2/9] Update src/smolagents/memory.py Co-authored-by: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> --- src/smolagents/memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/smolagents/memory.py b/src/smolagents/memory.py index b59de95d4..7c9a68824 100644 --- a/src/smolagents/memory.py +++ b/src/smolagents/memory.py @@ -106,7 +106,7 @@ def to_messages(self, summary_mode: bool = False, show_model_input_messages: boo content=[ { "type": "text", - "text": f"f{'Call id: {self.tool_calls[0].id}\n' if self.tool_calls else ''}Observation:\n{self.observations}", + "text": f"Call id: {self.tool_calls[0].id}\n" if self.tool_calls else "" + f"Observation:\n{self.observations}", } ], ) From 24647adec610fe3ec2c4cfb658c0139abf4bf9bb Mon Sep 17 00:00:00 2001 From: sysradium Date: Thu, 13 Feb 2025 22:51:34 +0100 Subject: [PATCH 3/9] add regression test --- src/smolagents/memory.py | 4 +++- tests/test_agents.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/smolagents/memory.py b/src/smolagents/memory.py index 7c9a68824..5e2ef763b 100644 --- a/src/smolagents/memory.py +++ b/src/smolagents/memory.py @@ -106,7 +106,9 @@ def to_messages(self, summary_mode: bool = False, show_model_input_messages: boo content=[ { "type": "text", - "text": f"Call id: {self.tool_calls[0].id}\n" if self.tool_calls else "" + f"Observation:\n{self.observations}", + "text": f"Call id: {self.tool_calls[0].id}\n" + if self.tool_calls + else "" + f"Observation:\n{self.observations}", } ], ) diff --git a/tests/test_agents.py b/tests/test_agents.py index 0488071b0..8549b94a2 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -33,7 +33,7 @@ populate_template, ) from smolagents.default_tools import DuckDuckGoSearchTool, FinalAnswerTool, PythonInterpreterTool, VisitWebpageTool -from smolagents.memory import PlanningStep +from smolagents.memory import PlanningStep, ActionStep from smolagents.models import ( ChatMessage, ChatMessageToolCall, @@ -304,6 +304,18 @@ def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> st ) +def fake_no_valid_code_block(messages, stop_sequences=None, grammar=None) -> ChatMessage: + return ChatMessage( + role="assistant", + content=""" +Thought: I should multiply 2 by 3.6452. special_marker +Code: + I am Incorrect Python Code !rint(result) +``` +""", + ) + + class AgentTests(unittest.TestCase): def test_fake_toolcalling_agent(self): agent = ToolCallingAgent(tools=[PythonInterpreterTool()], model=FakeToolCallModel()) @@ -735,6 +747,22 @@ def test_planning_step(self, step, expected_messages_list): for content, expected_content in zip(message["content"], expected_message["content"]): assert content == expected_content + def test_agent_memory_to_messages_suceeds_when_tool_fails_by_obeservation_is_set(self): + tool = PythonInterpreterTool() + + def _fake_callback(memory_step: ActionStep, agent: CodeAgent) -> None: + memory_step.observations = "observed something" + + agent = CodeAgent( + tools=[tool], + model=fake_no_valid_code_block, + step_callbacks=[_fake_callback], + ) + agent.run("some task") + + for s in agent.memory.steps: + s.to_messages() + @pytest.mark.parametrize( "images, expected_messages_list", [ From be444e87fabaa672a1317592f52b2c6a216d9adc Mon Sep 17 00:00:00 2001 From: sysradium Date: Thu, 13 Feb 2025 22:53:26 +0100 Subject: [PATCH 4/9] correct type hints for mocks --- src/smolagents/memory.py | 5 ++--- tests/test_agents.py | 14 +++++++------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/smolagents/memory.py b/src/smolagents/memory.py index 5e2ef763b..ac573a6e0 100644 --- a/src/smolagents/memory.py +++ b/src/smolagents/memory.py @@ -106,9 +106,8 @@ def to_messages(self, summary_mode: bool = False, show_model_input_messages: boo content=[ { "type": "text", - "text": f"Call id: {self.tool_calls[0].id}\n" - if self.tool_calls - else "" + f"Observation:\n{self.observations}", + "text": (f"Call id: {self.tool_calls[0].id}\n" if self.tool_calls else "") + + f"Observation:\n{self.observations}", } ], ) diff --git a/tests/test_agents.py b/tests/test_agents.py index 8549b94a2..0e32eebd7 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -146,7 +146,7 @@ def __call__(self, messages, tools_to_call_from=None, stop_sequences=None, gramm ) -def fake_code_model(messages, stop_sequences=None, grammar=None) -> str: +def fake_code_model(messages, stop_sequences=None, grammar=None) -> ChatMessage: prompt = str(messages) if "special_marker" not in prompt: return ChatMessage( @@ -172,7 +172,7 @@ def fake_code_model(messages, stop_sequences=None, grammar=None) -> str: ) -def fake_code_model_error(messages, stop_sequences=None) -> str: +def fake_code_model_error(messages, stop_sequences=None) -> ChatMessage: prompt = str(messages) if "special_marker" not in prompt: return ChatMessage( @@ -202,7 +202,7 @@ def error_function(): ) -def fake_code_model_syntax_error(messages, stop_sequences=None) -> str: +def fake_code_model_syntax_error(messages, stop_sequences=None) -> ChatMessage: prompt = str(messages) if "special_marker" not in prompt: return ChatMessage( @@ -231,7 +231,7 @@ def fake_code_model_syntax_error(messages, stop_sequences=None) -> str: ) -def fake_code_model_import(messages, stop_sequences=None) -> str: +def fake_code_model_import(messages, stop_sequences=None) -> ChatMessage: return ChatMessage( role="assistant", content=""" @@ -245,7 +245,7 @@ def fake_code_model_import(messages, stop_sequences=None) -> str: ) -def fake_code_functiondef(messages, stop_sequences=None) -> str: +def fake_code_functiondef(messages, stop_sequences=None) -> ChatMessage: prompt = str(messages) if "special_marker" not in prompt: return ChatMessage( @@ -276,7 +276,7 @@ def moving_average(x, w): ) -def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str: +def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> ChatMessage: return ChatMessage( role="assistant", content=""" @@ -290,7 +290,7 @@ def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> ) -def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str: +def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> ChatMessage: return ChatMessage( role="assistant", content=""" From 61dee8e98a1a37556efbf5fded0156151c318328 Mon Sep 17 00:00:00 2001 From: sysradium Date: Thu, 13 Feb 2025 23:06:37 +0100 Subject: [PATCH 5/9] allow python override in Makefile to be able to use it with uv --- Makefile | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index c8e7c04f6..63bb930bd 100644 --- a/Makefile +++ b/Makefile @@ -1,18 +1,20 @@ .PHONY: quality style test docs utils +PYTHON ?= python + check_dirs := examples src tests utils # Check code quality of the source code quality: - ruff check $(check_dirs) - ruff format --check $(check_dirs) - python utils/check_tests_in_ci.py + $(PYTHON) -m ruff check $(check_dirs) + $(PYTHON) -m ruff format --check $(check_dirs) + $(PYTHON) utils/check_tests_in_ci.py # Format source code automatically style: - ruff check $(check_dirs) --fix - ruff format $(check_dirs) + $(PYTHON) -m ruff check $(check_dirs) --fix + $(PYTHON) -m ruff format $(check_dirs) # Run smolagents tests test: - pytest ./tests/ \ No newline at end of file + $(PYTHON) -m pytest ./tests/ From 8cec8bf14d6f1576ca789c13165c7eb8b0eac13a Mon Sep 17 00:00:00 2001 From: sysradium Date: Thu, 13 Feb 2025 23:17:26 +0100 Subject: [PATCH 6/9] add missing mlx dep, as otherwise tests fail locally --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 266717318..91d6fa3c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,4 +111,4 @@ lines-after-imports = 2 [project.scripts] smolagent = "smolagents.cli:main" -webagent = "smolagents.vision_web_browser:main" \ No newline at end of file +webagent = "smolagents.vision_web_browser:main" From 5010fc92f82fae4f716b8013cbfff06c1e04a4b2 Mon Sep 17 00:00:00 2001 From: sysradium Date: Thu, 13 Feb 2025 23:18:07 +0100 Subject: [PATCH 7/9] force MultiAgentsTests to dump data into tmp dir, not into repo --- tests/test_agents.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_agents.py b/tests/test_agents.py index 0e32eebd7..758e40ae4 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -932,6 +932,10 @@ def fake_code_model(messages, stop_sequences=None, grammar=None) -> str: class MultiAgentsTests(unittest.TestCase): + @pytest.fixture(autouse=True) + def initdir(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + def test_multiagents_save(self): model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=2096, temperature=0.5) From c5b0edf2a3cf3a906d2a194692d3ed481a6b29a0 Mon Sep 17 00:00:00 2001 From: sysradium Date: Fri, 14 Feb 2025 21:34:26 +0100 Subject: [PATCH 8/9] small linter tweaks --- src/smolagents/memory.py | 13 ++++--------- src/smolagents/models.py | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/src/smolagents/memory.py b/src/smolagents/memory.py index ac573a6e0..0c6244b61 100644 --- a/src/smolagents/memory.py +++ b/src/smolagents/memory.py @@ -1,17 +1,12 @@ from dataclasses import asdict, dataclass from logging import getLogger -from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, Union +from typing import Any, List, TypedDict, Union from smolagents.models import ChatMessage, MessageRole from smolagents.monitoring import AgentLogger, LogLevel from smolagents.utils import AgentError, make_json_serializable -if TYPE_CHECKING: - from smolagents.models import ChatMessage - from smolagents.monitoring import AgentLogger - - logger = getLogger(__name__) @@ -42,7 +37,7 @@ class MemoryStep: def dict(self): return asdict(self) - def to_messages(self, **kwargs) -> List[Dict[str, Any]]: + def to_messages(self, **kwargs) -> List[Message]: raise NotImplementedError @@ -55,7 +50,7 @@ class ActionStep(MemoryStep): step_number: int | None = None error: AgentError | None = None duration: float | None = None - model_output_message: ChatMessage = None + model_output_message: ChatMessage | None = None model_output: str | None = None observations: str | None = None observations_images: List[str] | None = None @@ -78,7 +73,7 @@ def dict(self): } def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Message]: - messages = [] + messages: List[Message] = [] if self.model_input_messages is not None and show_model_input_messages: messages.append(Message(role=MessageRole.SYSTEM, content=self.model_input_messages)) if self.model_output is not None and not summary_mode: diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 2a586edfe..2f2a58603 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -105,7 +105,7 @@ def from_hf_api(cls, message, raw) -> "ChatMessage": return cls(role=message.role, content=message.content, tool_calls=tool_calls, raw=raw) @classmethod - def from_dict(cls, data: dict) -> "ChatMessage": + def from_dict(cls, data: dict[str, Any]) -> "ChatMessage": if data.get("tool_calls"): tool_calls = [ ChatMessageToolCall( @@ -123,11 +123,11 @@ def dict(self): def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]: if isinstance(arguments, dict): return arguments - else: - try: - return json.loads(arguments) - except Exception: - return arguments + + try: + return json.loads(arguments) + except Exception: + return arguments def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage: @@ -155,7 +155,7 @@ def roles(cls): } -def get_tool_json_schema(tool: Tool) -> Dict: +def get_tool_json_schema(tool: Tool) -> dict[str, Any]: properties = deepcopy(tool.inputs) required = [] for key, value in properties.items(): @@ -298,7 +298,7 @@ def _prepare_completion_kwargs( return completion_kwargs - def get_token_counts(self) -> Dict[str, int]: + def get_token_counts(self) -> dict[str, int | None]: return { "input_token_count": self.last_input_token_count, "output_token_count": self.last_output_token_count, @@ -331,11 +331,11 @@ def __call__( """ pass # To be implemented in child classes! - def to_dict(self) -> Dict: + def to_dict(self) -> dict[str, Any]: """ Converts the model into a JSON-compatible dictionary. """ - model_dictionary = { + model_dictionary: dict[str, Any] = { **self.kwargs, "last_input_token_count": self.last_input_token_count, "last_output_token_count": self.last_output_token_count, From db38e88233070969cf241bf9f7dc02cf6c0e4095 Mon Sep 17 00:00:00 2001 From: sysradium Date: Sat, 15 Feb 2025 16:49:51 +0100 Subject: [PATCH 9/9] explain test case --- pyproject.toml | 1 + tests/test_agents.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 91d6fa3c1..e9f542cbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ test = [ dev = [ "smolagents[quality,test]", "sqlalchemy", # for ./examples + "mlx-lm", ] [tool.pytest.ini_options] diff --git a/tests/test_agents.py b/tests/test_agents.py index 758e40ae4..11f46f409 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -33,7 +33,7 @@ populate_template, ) from smolagents.default_tools import DuckDuckGoSearchTool, FinalAnswerTool, PythonInterpreterTool, VisitWebpageTool -from smolagents.memory import PlanningStep, ActionStep +from smolagents.memory import ActionStep, PlanningStep from smolagents.models import ( ChatMessage, ChatMessageToolCall, @@ -747,7 +747,7 @@ def test_planning_step(self, step, expected_messages_list): for content, expected_content in zip(message["content"], expected_message["content"]): assert content == expected_content - def test_agent_memory_to_messages_suceeds_when_tool_fails_by_obeservation_is_set(self): + def test_agent_memory_to_messages_suceeds_when_tool_fails_but_obeservation_is_set(self): tool = PythonInterpreterTool() def _fake_callback(memory_step: ActionStep, agent: CodeAgent) -> None: @@ -758,8 +758,12 @@ def _fake_callback(memory_step: ActionStep, agent: CodeAgent) -> None: model=fake_no_valid_code_block, step_callbacks=[_fake_callback], ) + + # Perform a task. A tool call will fail since the fake model returns invalid response, but a _fake_callback + # sets an observation agent.run("some task") + # This should not fail, even though no tool call have been recorded, and observation came from a callback for s in agent.memory.steps: s.to_messages()