diff --git a/docs/source/en/conceptual_guides/react.md b/docs/source/en/conceptual_guides/react.md
index 417fb8590..b097d5e38 100644
--- a/docs/source/en/conceptual_guides/react.md
+++ b/docs/source/en/conceptual_guides/react.md
@@ -27,7 +27,7 @@ Initialization: the system prompt is stored in a `SystemPromptStep`, and the use
 
 While loop (ReAct loop):
 
-- Use `agent.write_inner_memory_from_logs()` to write the agent logs into a list of LLM-readable [chat messages](https://huggingface.co/docs/transformers/en/chat_templating).
+- Use `agent.write_memory_to_messages()` to write the agent's memory into a list of LLM-readable [chat messages](https://huggingface.co/docs/transformers/en/chat_templating).
 - Send these messages to a `Model` object to get its completion. Parse the completion to get the action (a JSON blob for `ToolCallingAgent`, a code snippet for `CodeAgent`).
 - Execute the action and logs result into memory (an `ActionStep`).
 - At the end of each step, we run all callback functions defined in `agent.step_callbacks` .
diff --git a/docs/source/en/guided_tour.md b/docs/source/en/guided_tour.md
index aebb4e23e..cbeaba881 100644
--- a/docs/source/en/guided_tour.md
+++ b/docs/source/en/guided_tour.md
@@ -189,7 +189,7 @@ agent.run("Could you get me the title of the page at url 'https://huggingface.co
 Here are a few useful attributes to inspect what happened after a run:
 - `agent.logs` stores the fine-grained logs of the agent. At every step of the agent's run, everything gets stored in a dictionary that then is appended to `agent.logs`.
-- Running `agent.write_inner_memory_from_logs()` creates an inner memory of the agent's logs for the LLM to view, as a list of chat messages. This method goes over each step of the log and only stores what it's interested in as a message: for instance, it will save the system prompt and task in separate messages, then for each step it will store the LLM output as a message, and the tool call output as another message. Use this if you want a higher-level view of what has happened - but not every log will be transcripted by this method.
+- Running `agent.write_memory_to_messages()` writes the agent's memory as a list of chat messages for the `Model` to view. This method goes over each step of the memory and only stores what it's interested in as a message: for instance, it will save the system prompt and task in separate messages, then for each step it will store the LLM output as a message, and the tool call output as another message. Use this if you want a higher-level view of what has happened - but not every log will be transcribed by this method.
 
 ## Tools
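[Editor's note — illustration only, not part of the patch. A minimal sketch of the renamed API described in the docs hunk above, assuming the public names introduced in this PR (`agent.memory.steps`, `write_memory_to_messages()`) and the existing `CodeAgent`/`HfApiModel` entry points:]

```python
from smolagents import CodeAgent, HfApiModel

agent = CodeAgent(tools=[], model=HfApiModel())
agent.run("What is 2 multiplied by 3.6452?")

# Steps now live on `agent.memory`; the deprecated `agent.logs` property
# still works, but warns and prepends the SystemPromptStep.
for step in agent.memory.steps:
    print(type(step).__name__)  # TaskStep, ActionStep, ...

# Rebuild the LLM-readable conversation from memory:
for message in agent.write_memory_to_messages():
    print(message["role"])
```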
diff --git a/docs/source/hi/guided_tour.md b/docs/source/hi/guided_tour.md
index 24cb71d03..4fb05b750 100644
--- a/docs/source/hi/guided_tour.md
+++ b/docs/source/hi/guided_tour.md
@@ -142,7 +142,7 @@ agent.run("Could you get me the title of the page at url 'https://huggingface.co
 रन के बाद क्या हुआ यह जांचने के लिए यहाँ कुछ उपयोगी एट्रिब्यूट्स हैं:
 - `agent.logs` एजेंट के फाइन-ग्रेन्ड लॉग्स को स्टोर करता है। एजेंट के रन के हर स्टेप पर, सब कुछ एक डिक्शनरी में स्टोर किया जाता है जो फिर `agent.logs` में जोड़ा जाता है।
-- `agent.write_inner_memory_from_logs()` चलाने से LLM के लिए एजेंट के लॉग्स की एक इनर मेमोरी बनती है, चैट मैसेज की लिस्ट के रूप में। यह मेथड लॉग के प्रत्येक स्टेप पर जाता है और केवल वही स्टोर करता है जिसमें यह एक मैसेज के रूप में रुचि रखता है: उदाहरण के लिए, यह सिस्टम प्रॉम्प्ट और टास्क को अलग-अलग मैसेज के रूप में सेव करेगा, फिर प्रत्येक स्टेप के लिए यह LLM आउटपुट को एक मैसेज के रूप में और टूल कॉल आउटपुट को दूसरे मैसेज के रूप में स्टोर करेगा।
+- `agent.write_memory_to_messages()` चलाने से LLM के लिए एजेंट के लॉग्स की एक इनर मेमोरी बनती है, चैट मैसेज की लिस्ट के रूप में। यह मेथड लॉग के प्रत्येक स्टेप पर जाता है और केवल वही स्टोर करता है जिसमें यह एक मैसेज के रूप में रुचि रखता है: उदाहरण के लिए, यह सिस्टम प्रॉम्प्ट और टास्क को अलग-अलग मैसेज के रूप में सेव करेगा, फिर प्रत्येक स्टेप के लिए यह LLM आउटपुट को एक मैसेज के रूप में और टूल कॉल आउटपुट को दूसरे मैसेज के रूप में स्टोर करेगा।
 
 ## टूल्स
diff --git a/docs/source/zh/guided_tour.md b/docs/source/zh/guided_tour.md
index 537e5948e..aaaf2bf46 100644
--- a/docs/source/zh/guided_tour.md
+++ b/docs/source/zh/guided_tour.md
@@ -152,7 +152,7 @@ agent.run("Could you get me the title of the page at url 'https://huggingface.co
 以下是一些有用的属性,用于检查运行后发生了什么:
 - `agent.logs` 存储 agent 的细粒度日志。在 agent 运行的每一步,所有内容都会存储在一个字典中,然后附加到 `agent.logs` 中。
-- 运行 `agent.write_inner_memory_from_logs()` 会为 LLM 创建一个 agent 日志的内部内存,作为聊天消息列表。此方法会遍历日志的每一步,并仅存储它感兴趣的内容作为消息:例如,它会将系统提示和任务存储为单独的消息,然后对于每一步,它会将 LLM 输出存储为一条消息,工具调用输出存储为另一条消息。如果您想要更高级别的视图 - 但不是每个日志都会被此方法转录。
+- 运行 `agent.write_memory_to_messages()` 会为 LLM 创建一个 agent 日志的内部内存,作为聊天消息列表。此方法会遍历日志的每一步,并仅存储它感兴趣的内容作为消息:例如,它会将系统提示和任务存储为单独的消息,然后对于每一步,它会将 LLM 输出存储为一条消息,工具调用输出存储为另一条消息。如果您想要更高级别的视图,可以使用此方法 - 但并非每条日志都会被此方法转录。
 
 ## 工具
diff --git a/pyproject.toml b/pyproject.toml
index 93def1bfa..6edf550d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,11 +53,17 @@ mcp = [
 openai = [
   "openai>=1.58.1"
 ]
+telemetry = [
+  "arize-phoenix",
+  "opentelemetry-sdk",
+  "opentelemetry-exporter-otlp",
+  "openinference-instrumentation-smolagents>=0.1.1"
+]
 quality = [
   "ruff>=0.9.0",
 ]
 all = [
-  "smolagents[accelerate,audio,e2b,gradio,litellm,mcp,openai,transformers]",
+  "smolagents[accelerate,audio,e2b,gradio,litellm,mcp,openai,transformers,telemetry]",
 ]
 test = [
   "ipython>=8.31.0", # for interactive environment tests
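[Editor's note on the new `telemetry` extra — a minimal instrumentation sketch, not part of the patch. It assumes a local Phoenix collector on the default endpoint and the `SmolagentsInstrumentor` entry point shipped by `openinference-instrumentation-smolagents`:]

```python
# pip install "smolagents[telemetry]"
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor

# Assumes Phoenix is running locally, e.g. `python -m phoenix.server.main serve`
endpoint = "http://0.0.0.0:6006/v1/traces"
trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

# Every agent run after this call is traced and viewable in Phoenix.
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
```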
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index 12a85693e..74c7c389c 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -17,7 +17,7 @@
 import inspect
 import time
 from collections import deque
-from dataclasses import dataclass
+from logging import getLogger
 from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 
 from rich import box
@@ -27,6 +27,24 @@
 from rich.syntax import Syntax
 from rich.text import Text
 
+from smolagents.logger import (
+    AgentLogger,
+    LogLevel,
+)
+from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
+from smolagents.models import ChatMessage
+from smolagents.types import AgentAudio, AgentImage, handle_agent_output_types
+from smolagents.utils import (
+    AgentError,
+    AgentExecutionError,
+    AgentGenerationError,
+    AgentMaxStepsError,
+    AgentParsingError,
+    parse_code_blobs,
+    parse_json_tool_call,
+    truncate_content,
+)
+
 from .default_tools import TOOL_MAPPING, FinalAnswerTool
 from .e2b_executor import E2BExecutor
 from .local_python_executor import (
@@ -57,62 +74,10 @@
     Tool,
     get_tool_description_with_args,
 )
-from .types import AgentAudio, AgentImage, AgentType, handle_agent_output_types
-from .utils import (
-    AgentError,
-    AgentExecutionError,
-    AgentGenerationError,
-    AgentLogger,
-    AgentMaxStepsError,
-    AgentParsingError,
-    LogLevel,
-    parse_code_blobs,
-    parse_json_tool_call,
-    truncate_content,
-)
-
-
-@dataclass
-class ToolCall:
-    name: str
-    arguments: Any
-    id: str
-
-
-class AgentStepLog:
-    pass
+from .types import AgentType
 
-@dataclass
-class ActionStep(AgentStepLog):
-    agent_memory: List[Dict[str, str]] | None = None
-    tool_calls: List[ToolCall] | None = None
-    start_time: float | None = None
-    end_time: float | None = None
-    step_number: int | None = None
-    error: AgentError | None = None
-    duration: float | None = None
-    llm_output: str | None = None
-    observations: str | None = None
-    observations_images: List[str] | None = None
-    action_output: Any = None
-
-
-@dataclass
-class PlanningStep(AgentStepLog):
-    plan: str
-    facts: str
-
-
-@dataclass
-class TaskStep(AgentStepLog):
-    task: str
-    task_images: List[str] | None = None
-
-
-@dataclass
-class SystemPromptStep(AgentStepLog):
-    system_prompt: str
+logger = getLogger(__name__)
 
 
 def get_tool_descriptions(tools: Dict[str, Tool], tool_description_template: str) -> str:
@@ -227,157 +192,60 @@ def __init__(
         self.system_prompt = self.initialize_system_prompt()
         self.input_messages = None
-        self.logs = []
         self.task = None
+        self.memory = AgentMemory(self.system_prompt)
         self.logger = AgentLogger(level=verbosity_level)
         self.monitor = Monitor(self.model, self.logger)
         self.step_callbacks = step_callbacks if step_callbacks is not None else []
         self.step_callbacks.append(self.monitor.update_metrics)
 
+    @property
+    def logs(self):
+        logger.warning(
+            "The 'logs' attribute is deprecated and will soon be removed. Please use 'self.memory.steps' instead."
+        )
+        return [self.memory.system_prompt] + self.memory.steps
+
     def initialize_system_prompt(self):
-        self.system_prompt = format_prompt_with_tools(
+        system_prompt = format_prompt_with_tools(
             self.tools,
             self.system_prompt_template,
             self.tool_description_template,
         )
-        self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
+        system_prompt = format_prompt_with_managed_agents_descriptions(system_prompt, self.managed_agents)
+        return system_prompt
 
-        return self.system_prompt
-
-    def write_inner_memory_from_logs(self, summary_mode: bool = False) -> List[Dict[str, str]]:
+    def write_memory_to_messages(
+        self,
+        summary_mode: Optional[bool] = False,
+    ) -> List[Dict[str, str]]:
         """
-        Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
-        that can be used as input to the LLM.
-
-        Args:
-            summary_mode (`bool`): Whether to write a summary of the logs or the full logs.
+        Reads past llm_outputs, actions, and observations or errors from the memory into a series of messages
+        that can be used as input to the LLM. Adds a number of keywords (such as PLAN, error, etc.) to help
+        the LLM.
""" - memory = [] - for i, step_log in enumerate(self.logs): - if isinstance(step_log, SystemPromptStep): - if not summary_mode: - thought_message = { - "role": MessageRole.SYSTEM, - "content": [{"type": "text", "text": step_log.system_prompt.strip()}], - } - memory.append(thought_message) - - elif isinstance(step_log, PlanningStep): - thought_message = { - "role": MessageRole.ASSISTANT, - "content": [{"type": "text", "text": "[FACTS LIST]:\n" + step_log.facts.strip()}], - } - memory.append(thought_message) - - if not summary_mode: - thought_message = { - "role": MessageRole.ASSISTANT, - "content": [{"type": "text", "text": "[PLAN]:\n" + step_log.plan.strip()}], - } - memory.append(thought_message) - - elif isinstance(step_log, TaskStep): - task_message = { - "role": MessageRole.USER, - "content": [{"type": "text", "text": f"New task:\n{step_log.task}"}], - } - if step_log.task_images: - for image in step_log.task_images: - task_message["content"].append({"type": "image", "image": image}) - memory.append(task_message) - - elif isinstance(step_log, ActionStep): - if step_log.llm_output is not None and not summary_mode: - thought_message = { - "role": MessageRole.ASSISTANT, - "content": [{"type": "text", "text": step_log.llm_output.strip()}], - } - memory.append(thought_message) - if step_log.tool_calls is not None: - tool_call_message = { - "role": MessageRole.ASSISTANT, - "content": [ - { - "type": "text", - "text": str( - [ - { - "id": tool_call.id, - "type": "function", - "function": { - "name": tool_call.name, - "arguments": tool_call.arguments, - }, - } - for tool_call in step_log.tool_calls - ] - ), - } - ], - } - memory.append(tool_call_message) - if step_log.error is not None: - error_message = { - "role": MessageRole.ASSISTANT, - "content": [ - { - "type": "text", - "text": ( - "Error:\n" - + str(step_log.error) - + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n" - ), - } - ], - } - memory.append(error_message) - if step_log.observations is not None: - if step_log.tool_calls: - tool_call_reference = f"Call id: {(step_log.tool_calls[0].id if getattr(step_log.tool_calls[0], 'id') else 'call_0')}\n" - else: - tool_call_reference = "" - text_observations = f"Observation:\n{step_log.observations}" - tool_response_message = { - "role": MessageRole.TOOL_RESPONSE, - "content": [{"type": "text", "text": tool_call_reference + text_observations}], - } - memory.append(tool_response_message) - if step_log.observations_images: - thought_message_image = { - "role": MessageRole.USER, - "content": [{"type": "text", "text": "Here are the observed images:"}] - + [ - { - "type": "image", - "image": image, - } - for image in step_log.observations_images - ], - } - memory.append(thought_message_image) + messages = self.memory.system_prompt.to_messages(summary_mode=summary_mode) + for step_log in self.memory.steps: + messages.extend(step_log.to_messages(summary_mode=summary_mode)) + return messages - return memory - - def get_succinct_logs(self): - return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs] - - def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]: + def extract_action(self, model_output: str, split_token: str) -> Tuple[str, str]: """ Parse action from the LLM output Args: - llm_output (`str`): Output of the LLM + model_output (`str`): Output of the LLM split_token (`str`): Separator for the action. 
Should match the example in the system prompt. """ try: - split = llm_output.split(split_token) + split = model_output.split(split_token) rationale, action = ( split[-2], split[-1], ) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output except Exception: raise AgentParsingError( - f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!", + f"No '{split_token}' token provided in your output.\nYour output:\n{model_output}\n. Be sure to include an action, prefaced with '{split_token}'!", self.logger, ) return rationale.strip(), action.strip() @@ -393,16 +261,17 @@ def provide_final_answer(self, task: str, images: Optional[list[str]]) -> str: Returns: `str`: Final answer to the task. """ + messages = [{"role": MessageRole.SYSTEM, "content": []}] if images: - self.input_messages[0]["content"] = [ + messages[0]["content"] = [ { "type": "text", "text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", } ] - self.input_messages[0]["content"].append({"type": "image"}) - self.input_messages += self.write_inner_memory_from_logs()[1:] - self.input_messages += [ + messages[0]["content"].append({"type": "image"}) + messages += self.write_memory_to_messages()[1:] + messages += [ { "role": MessageRole.USER, "content": [ @@ -414,14 +283,14 @@ def provide_final_answer(self, task: str, images: Optional[list[str]]) -> str: } ] else: - self.input_messages[0]["content"] = [ + messages[0]["content"] = [ { "type": "text", "text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. 
Here is the agent's memory:", } ] - self.input_messages += self.write_inner_memory_from_logs()[1:] - self.input_messages += [ + messages += self.write_memory_to_messages()[1:] + messages += [ { "role": MessageRole.USER, "content": [ @@ -433,7 +302,8 @@ def provide_final_answer(self, task: str, images: Optional[list[str]]) -> str: } ] try: - return self.model(self.input_messages).content + chat_message: ChatMessage = self.model(messages) + return chat_message.content except Exception as e: return f"Error in generating final LLM output:\n{e}" @@ -523,18 +393,11 @@ def run( You have been provided with these additional arguments, that you can access using the keys as variables in your python code: {str(additional_args)}.""" - self.initialize_system_prompt() - system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt) - + self.system_prompt = self.initialize_system_prompt() + self.memory.system_prompt = SystemPromptStep(system_prompt=self.system_prompt) if reset: - self.logs = [] - self.logs.append(system_prompt_step) + self.memory.reset() self.monitor.reset() - else: - if len(self.logs) > 0: - self.logs[0] = system_prompt_step - else: - self.logs.append(system_prompt_step) self.logger.log( Panel( @@ -547,7 +410,7 @@ def run( level=LogLevel.INFO, ) - self.logs.append(TaskStep(task=self.task, task_images=images)) + self.memory.steps.append(TaskStep(task=self.task, task_images=images)) if single_step: step_start_time = time.time() step_log = ActionStep(start_time=step_start_time, observations_images=images) @@ -604,7 +467,7 @@ def _run(self, task: str, images: List[str] | None = None) -> Generator[ActionSt finally: step_log.end_time = time.time() step_log.duration = step_log.end_time - step_start_time - self.logs.append(step_log) + self.memory.steps.append(step_log) for callback in self.step_callbacks: # For compatibility with old callbacks that don't take the agent as an argument if len(inspect.signature(callback).parameters) == 1: @@ -616,15 +479,16 @@ def _run(self, task: str, images: List[str] | None = None) -> Generator[ActionSt if final_answer is None and self.step_number == self.max_steps: error_message = "Reached max steps." 
+            final_answer = self.provide_final_answer(task, images)
             final_step_log = ActionStep(
                 step_number=self.step_number, error=AgentMaxStepsError(error_message, self.logger)
             )
-            self.logs.append(final_step_log)
-            final_answer = self.provide_final_answer(task, images)
             self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO)
             final_step_log.action_output = final_answer
             final_step_log.end_time = time.time()
             final_step_log.duration = step_log.end_time - step_start_time
+            self.memory.steps.append(final_step_log)
             for callback in self.step_callbacks:
                 # For compatibility with old callbacks that don't take the agent as an argument
                 if len(inspect.signature(callback).parameters) == 1:
@@ -658,7 +522,8 @@ def planning_step(self, task, is_first_step: bool, step: int) -> None:
 Now begin!""",
             }
-            answer_facts = self.model([message_prompt_facts, message_prompt_task]).content
+            chat_message_facts: ChatMessage = self.model([message_prompt_facts, message_prompt_task])
+            answer_facts = chat_message_facts.content
 
             message_system_prompt_plan = {
                 "role": MessageRole.SYSTEM,
@@ -673,10 +538,11 @@ def planning_step(self, task, is_first_step: bool, step: int) -> None:
                     answer_facts=answer_facts,
                 ),
             }
-            answer_plan = self.model(
+            chat_message_plan: ChatMessage = self.model(
                 [message_system_prompt_plan, message_user_prompt_plan],
                 stop_sequences=["<end_plan>"],
-            ).content
+            )
+            answer_plan = chat_message_plan.content
 
             final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
             ```
@@ -686,14 +552,21 @@ def planning_step(self, task, is_first_step: bool, step: int) -> None:
             ```
             {answer_facts}
             ```""".strip()
-            self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
+            self.memory.steps.append(
+                PlanningStep(
+                    plan=final_plan_redaction,
+                    facts=final_facts_redaction,
+                    model_output_message_plan=chat_message_plan,
+                    model_output_message_facts=chat_message_facts,
+                )
+            )
             self.logger.log(
                 Rule("[bold]Initial plan", style="orange"),
                 Text(final_plan_redaction),
                 level=LogLevel.INFO,
             )
         else:  # update plan
-            agent_memory = self.write_inner_memory_from_logs(
+            memory_messages = self.write_memory_to_messages(
                 summary_mode=False
             )  # This will not log the plan but will log facts
 
@@ -706,7 +579,10 @@
                 "role": MessageRole.USER,
                 "content": [{"type": "text", "text": USER_PROMPT_FACTS_UPDATE}],
             }
-            facts_update = self.model([facts_update_system_prompt] + agent_memory + [facts_update_message]).content
+            chat_message_facts: ChatMessage = self.model(
+                [facts_update_system_prompt] + memory_messages + [facts_update_message]
+            )
+            facts_update = chat_message_facts.content
 
             # Redact updated plan
             plan_update_message = {
@@ -728,18 +604,27 @@
                     }
                 ],
             }
-            plan_update = self.model(
-                [plan_update_message] + agent_memory + [plan_update_message_user],
+            chat_message_plan: ChatMessage = self.model(
+                [plan_update_message] + memory_messages + [plan_update_message_user],
                 stop_sequences=["<end_plan>"],
-            ).content
+            )
 
             # Log final facts and plan
-            final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
+            final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(
+                task=task, plan_update=chat_message_plan.content
+            )
             final_facts_redaction = f"""Here is the updated list of the facts that I know:
             ```
            {facts_update}
            ```"""
-            self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
+            self.memory.steps.append(
+                PlanningStep(
+                    plan=final_plan_redaction,
+                    facts=final_facts_redaction,
+                    model_output_message_plan=chat_message_plan,
+                    model_output_message_facts=chat_message_facts,
+                )
+            )
             self.logger.log(
                 Rule("[bold]Updated plan", style="orange"),
                 Text(final_plan_redaction),
@@ -783,19 +668,20 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
         Returns None if the step is not final.
         """
-        agent_memory = self.write_inner_memory_from_logs()
+        memory_messages = self.write_memory_to_messages()
 
-        self.input_messages = agent_memory
+        self.input_messages = memory_messages
 
         # Add new step in logs
-        log_entry.agent_memory = agent_memory.copy()
+        log_entry.model_input_messages = memory_messages.copy()
 
         try:
-            model_message = self.model(
-                self.input_messages,
+            model_message: ChatMessage = self.model(
+                memory_messages,
                 tools_to_call_from=list(self.tools.values()),
                 stop_sequences=["Observation:"],
             )
+            log_entry.model_output_message = model_message
             if model_message.tool_calls is None or len(model_message.tool_calls) == 0:
                 raise Exception("Model did not call any tools. Call `final_answer` tool to return a final answer.")
             tool_call = model_message.tool_calls[0]
@@ -931,7 +817,7 @@ def __init__(
         )
 
     def initialize_system_prompt(self):
-        super().initialize_system_prompt()
+        self.system_prompt = super().initialize_system_prompt()
         self.system_prompt = self.system_prompt.replace(
             "{{authorized_imports}}",
             (
@@ -947,20 +833,22 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
         Returns None if the step is not final.
         """
-        agent_memory = self.write_inner_memory_from_logs()
+        memory_messages = self.write_memory_to_messages()
 
-        self.input_messages = agent_memory.copy()
+        self.input_messages = memory_messages.copy()
 
         # Add new step in logs
-        log_entry.agent_memory = agent_memory.copy()
+        log_entry.model_input_messages = memory_messages.copy()
 
         try:
             additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
-            llm_output = self.model(
+            chat_message: ChatMessage = self.model(
                 self.input_messages,
                 stop_sequences=["<end_code>", "Observation:"],
                 **additional_args,
-            ).content
-            log_entry.llm_output = llm_output
+            )
+            log_entry.model_output_message = chat_message
+            model_output = chat_message.content
+            log_entry.model_output = model_output
         except Exception as e:
             raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) from e
 
@@ -972,7 +860,7 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
                 style="orange",
             ),
             Syntax(
-                llm_output,
+                model_output,
                 lexer="markdown",
                 theme="github-dark",
                 word_wrap=True,
@@ -983,7 +871,7 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
 
         # Parse
         try:
-            code_action = fix_final_answer_code(parse_code_blobs(llm_output))
+            code_action = fix_final_answer_code(parse_code_blobs(model_output))
         except Exception as e:
             error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
            raise AgentParsingError(error_msg, self.logger)
 
@@ -992,7 +880,7 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
             ToolCall(
                 name="python_interpreter",
                 arguments=code_action,
-                id=f"call_{len(self.logs)}",
+                id=f"call_{len(self.memory.steps)}",
             )
         ]
 
@@ -1095,7 +983,7 @@ def __call__(self, request, **kwargs):
             answer = f"Here is the final answer from your managed agent '{self.name}':\n"
             answer += str(output)
             answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
-            for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
+            for message in self.agent.write_memory_to_messages(summary_mode=True):
                 content = message["content"]
                 answer += "\n" + truncate_content(str(content)) + "\n---"
             answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
@@ -1104,9 +992,4 @@ def __call__(self, request, **kwargs):
             return output
 
 
-__all__ = [
-    "ManagedAgent",
-    "MultiStepAgent",
-    "CodeAgent",
-    "ToolCallingAgent",
-]
+__all__ = ["ManagedAgent", "MultiStepAgent", "CodeAgent", "ToolCallingAgent", "AgentMemory"]
diff --git a/src/smolagents/e2b_executor.py b/src/smolagents/e2b_executor.py
index 3d3eee1e5..fe87011fa 100644
--- a/src/smolagents/e2b_executor.py
+++ b/src/smolagents/e2b_executor.py
@@ -74,14 +74,16 @@ def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
             tool_codes.append(tool_code)
 
         tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
-        tool_definition_code += textwrap.dedent("""
+        tool_definition_code += textwrap.dedent(
+            """
         class Tool:
             def __call__(self, *args, **kwargs):
                 return self.forward(*args, **kwargs)
 
             def forward(self, *args, **kwargs):
                 pass # to be implemented in child class
-        """)
+        """
+        )
         tool_definition_code += "\n\n".join(tool_codes)
 
         tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py
index d7d1211e8..bc367aabd 100644
--- a/src/smolagents/gradio_ui.py
+++ b/src/smolagents/gradio_ui.py
@@ -19,13 +19,14 @@
 import shutil
 from typing import Optional
 
-from .agents import ActionStep, AgentStepLog, MultiStepAgent
-from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
-from .utils import _is_package_available
+from smolagents.agents import ActionStep, MultiStepAgent
+from smolagents.memory import MemoryStep
+from smolagents.types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
+from smolagents.utils import _is_package_available
 
 
 def pull_messages_from_step(
-    step_log: AgentStepLog,
+    step_log: MemoryStep,
 ):
     """Extract ChatMessage objects from agent steps with proper nesting"""
     import gradio as gr
@@ -36,15 +37,15 @@ def pull_messages_from_step(
         yield gr.ChatMessage(role="assistant", content=f"**{step_number}**")
 
         # First yield the thought/reasoning from the LLM
-        if hasattr(step_log, "llm_output") and step_log.llm_output is not None:
+        if hasattr(step_log, "model_output") and step_log.model_output is not None:
             # Clean up the LLM output
-            llm_output = step_log.llm_output.strip()
+            model_output = step_log.model_output.strip()
             # Remove any trailing <end_code> and extra backticks, handling multiple possible formats
-            llm_output = re.sub(r"```\s*<end_code>", "```", llm_output)  # handles ```<end_code>
-            llm_output = re.sub(r"<end_code>\s*```", "```", llm_output)  # handles <end_code>```
-            llm_output = re.sub(r"```\s*\n\s*<end_code>", "```", llm_output)  # handles ```\n<end_code>
-            llm_output = llm_output.strip()
-            yield gr.ChatMessage(role="assistant", content=llm_output)
+            model_output = re.sub(r"```\s*<end_code>", "```", model_output)  # handles ```<end_code>
+            model_output = re.sub(r"<end_code>\s*```", "```", model_output)  # handles <end_code>```
+            model_output = re.sub(r"```\s*\n\s*<end_code>", "```", model_output)  # handles ```\n<end_code>
+            model_output = model_output.strip()
+            yield gr.ChatMessage(role="assistant", content=model_output)
 
         # For tool calls, create a parent message
         if hasattr(step_log, "tool_calls") and step_log.tool_calls is not None:
diff --git a/src/smolagents/logger.py b/src/smolagents/logger.py
new file mode 100644
index 000000000..c74f6b101
--- /dev/null
+++ b/src/smolagents/logger.py
@@ -0,0 +1,85 @@
+import json
+from enum import IntEnum
+from typing import TYPE_CHECKING
+
+from rich.console import Console
+from rich.rule import Rule
+from rich.syntax import Syntax
+
+
+if TYPE_CHECKING:
+    from smolagents.memory import AgentMemory
+
+
+class LogLevel(IntEnum):
+    ERROR = 0  # Only errors
+    INFO = 1  # Normal output (default)
+    DEBUG = 2  # Detailed output
+
+
+class AgentLogger:
+    def __init__(self, level: LogLevel = LogLevel.INFO):
+        self.level = level
+        self.console = Console()
+
+    def log(self, *args, level: str | LogLevel = LogLevel.INFO, **kwargs):
+        """Logs a message to the console.
+
+        Args:
+            level (LogLevel, optional): Defaults to LogLevel.INFO.
+        """
+        if isinstance(level, str):
+            level = LogLevel[level.upper()]
+        if level <= self.level:
+            self.console.print(*args, **kwargs)
+
+    def replay(self, agent_memory: "AgentMemory", full: bool = False):
+        """Prints a pretty replay of the agent's steps.
+
+        Args:
+            full (bool, optional): If True, also displays the memory at each step. Defaults to False.
+                Careful: will increase log length considerably. Use only for debugging.
+        """
+        memory = []
+        for step_log in agent_memory.steps:
+            memory.extend(step_log.to_messages(show_model_input_messages=full))
+
+        self.console.log("Replaying the agent's steps:")
+        ix = 0
+        for step in memory:
+            role = step["role"].strip()
+            if ix > 0 and role == "system":
+                role = "memory"
+            theme = "default"
+            match role:
+                case "assistant":
+                    theme = "monokai"
+                    ix += 1
+                case "system":
+                    theme = "monokai"
+                case "tool-response":
+                    theme = "github_dark"
+
+            content = step["content"]
+            try:
+                content = eval(content)
+            except Exception:
+                content = [step["content"]]
+
+            for substep_ix, item in enumerate(content):
+                self.console.log(
+                    Rule(
+                        f"{role.upper()}, STEP {ix}, SUBSTEP {substep_ix + 1}/{len(content)}",
+                        align="center",
+                        style="orange",
+                    ),
+                    Syntax(
+                        json.dumps(item, indent=4) if isinstance(item, dict) else str(item),
+                        lexer="json",
+                        theme=theme,
+                        word_wrap=True,
+                    ),
+                )
+
+
+__all__ = ["AgentLogger"]
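[Editor's note — a quick illustration of the `AgentLogger` verbosity gating introduced above; sketch only, not part of the patch:]

```python
from smolagents.logger import AgentLogger, LogLevel

logger = AgentLogger(level=LogLevel.INFO)
logger.log("always shown", level=LogLevel.ERROR)  # ERROR (0) <= INFO (1): printed
logger.log("shown at INFO and DEBUG")             # defaults to LogLevel.INFO: printed
logger.log("only shown at DEBUG", level="debug")  # str levels are upcased; DEBUG (2) > INFO (1): skipped
```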
diff --git a/src/smolagents/memory.py b/src/smolagents/memory.py
new file mode 100644
index 000000000..1a32135ad
--- /dev/null
+++ b/src/smolagents/memory.py
@@ -0,0 +1,194 @@
+from dataclasses import asdict, dataclass
+from logging import getLogger
+from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, Union
+
+from smolagents.models import ChatMessage, MessageRole
+from smolagents.utils import AgentError, make_json_serializable
+
+
+if TYPE_CHECKING:
+    from smolagents.models import ChatMessage
+
+
+logger = getLogger(__name__)
+
+
+class Message(TypedDict):
+    role: MessageRole
+    content: str | list[dict]
+
+
+@dataclass
+class ToolCall:
+    name: str
+    arguments: Any
+    id: str
+
+    def dict(self):
+        return {
+            "id": self.id,
+            "type": "function",
+            "function": {
+                "name": self.name,
+                "arguments": make_json_serializable(self.arguments),
+            },
+        }
+
+
+class MemoryStep:
+    raw: Any  # This is a placeholder for the raw data that the agent logs
+
+    def dict(self):
+        return asdict(self)
+
+    def to_messages(self, **kwargs) -> List[Dict[str, Any]]:
+        raise NotImplementedError
+
+
+@dataclass
+class ActionStep(MemoryStep):
+    model_input_messages: List[Dict[str, str]] | None = None
+    tool_calls: List[ToolCall] | None = None
+    start_time: float | None = None
+    end_time: float | None = None
+    step_number: int | None = None
+    error: AgentError | None = None
+    duration: float | None = None
+    model_output_message: ChatMessage | None = None
+    model_output: str | None = None
+    observations: str | None = None
+    observations_images: List[str] | None = None
+    action_output: Any = None
+
+    def dict(self):
+        # We overwrite the method to parse the tool_calls and action_output manually
+        return {
+            "model_input_messages": self.model_input_messages,
+            "tool_calls": [tc.dict() for tc in self.tool_calls] if self.tool_calls else [],
+            "start_time": self.start_time,
+            "end_time": self.end_time,
+            "step": self.step_number,
+            "error": self.error.dict() if self.error else None,
+            "duration": self.duration,
+            "model_output_message": self.model_output_message,
+            "model_output": self.model_output,
+            "observations": self.observations,
+            "action_output": make_json_serializable(self.action_output),
+        }
+
+    def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Dict[str, Any]]:
+        messages = []
+        if self.model_input_messages is not None and show_model_input_messages:
+            messages.append(Message(role=MessageRole.SYSTEM, content=self.model_input_messages))
+        if self.model_output is not None and not summary_mode:
+            messages.append(
+                Message(role=MessageRole.ASSISTANT, content=[{"type": "text", "text": self.model_output.strip()}])
+            )
+
+        if self.tool_calls is not None:
+            messages.append(
+                Message(
+                    role=MessageRole.ASSISTANT,
+                    content=[{"type": "text", "text": str([tc.dict() for tc in self.tool_calls])}],
+                )
+            )
+
+        if self.error is not None:
+            message_content = (
+                "Error:\n"
+                + str(self.error)
+                + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
+            )
+            if self.tool_calls is None:
+                tool_response_message = Message(
+                    role=MessageRole.ASSISTANT, content=[{"type": "text", "text": message_content}]
+                )
+            else:
+                tool_response_message = Message(
+                    role=MessageRole.TOOL_RESPONSE, content=f"Call id: {self.tool_calls[0].id}\n{message_content}"
+                )
+
+            messages.append(tool_response_message)
+        else:
+            if self.observations is not None and self.tool_calls is not None:
+                messages.append(
+                    Message(
+                        role=MessageRole.TOOL_RESPONSE,
+                        content=f"Call id: {self.tool_calls[0].id}\nObservation:\n{self.observations}",
+                    )
+                )
+        if self.observations_images:
+            messages.append(
+                Message(
+                    role=MessageRole.USER,
+                    content=[{"type": "text", "text": "Here are the observed images:"}]
+                    + [
+                        {
+                            "type": "image",
+                            "image": image,
+                        }
+                        for image in self.observations_images
+                    ],
+                )
+            )
+        return messages
+
+
+@dataclass
+class PlanningStep(MemoryStep):
+    model_output_message_facts: ChatMessage
+    facts: str
+    model_output_message_plan: ChatMessage
+    plan: str
+
+    def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Dict[str, str]]:
+        messages = []
+        messages.append(Message(role=MessageRole.ASSISTANT, content=f"[FACTS LIST]:\n{self.facts.strip()}"))
+
+        if not summary_mode:
+            messages.append(Message(role=MessageRole.ASSISTANT, content=f"[PLAN]:\n{self.plan.strip()}"))
+        return messages
+
+
+@dataclass
+class TaskStep(MemoryStep):
+    task: str
+    task_images: List[str] | None = None
+
+    def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Dict[str, str]]:
+        content = [{"type": "text", "text": f"New task:\n{self.task}"}]
+        if self.task_images:
+            for image in self.task_images:
+                content.append({"type": "image", "image": image})
+
+        return [Message(role=MessageRole.USER, content=content)]
+
+
+@dataclass
+class SystemPromptStep(MemoryStep):
+    system_prompt: str
+
+    def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Dict[str, str]]:
+        if summary_mode:
+            return []
+        return [Message(role=MessageRole.SYSTEM, content=[{"type": "text", "text": self.system_prompt.strip()}])]
+
+
+class AgentMemory:
+    def __init__(self, system_prompt: str):
+        self.system_prompt = SystemPromptStep(system_prompt=system_prompt)
+        self.steps: List[Union[TaskStep, ActionStep, PlanningStep]] = []
+
+    def reset(self):
+        self.steps = []
+
+    def get_succinct_steps(self) -> list[dict]:
+        return [
+            {key: value for key, value in step.dict().items() if key != "model_input_messages"} for step in self.steps
+        ]
+
+    def get_full_steps(self) -> list[dict]:
+        return [step.dict() for step in self.steps]
+
+
+__all__ = ["AgentMemory"]
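[Editor's note — an assumed-usage illustration of the new memory primitives defined above; sketch only, not part of the patch:]

```python
from smolagents.memory import ActionStep, AgentMemory, TaskStep, ToolCall

memory = AgentMemory(system_prompt="You are a helpful agent.")
memory.steps.append(TaskStep(task="What is 2 * 3.6452?"))
memory.steps.append(
    ActionStep(
        step_number=1,
        tool_calls=[ToolCall(name="python_interpreter", arguments="final_answer(7.2904)", id="call_1")],
        observations="7.2904",
    )
)

# Serializable views exercised by the updated tests below:
print(memory.get_succinct_steps())  # same as get_full_steps(), minus model_input_messages

# Replay the steps as chat messages:
for step in memory.steps:
    print(step.to_messages(summary_mode=False))
```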
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index 063069270..2b91d9be2 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -47,10 +47,10 @@
 }
 
 
-def get_dict_from_nested_dataclasses(obj):
+def get_dict_from_nested_dataclasses(obj, ignore_key=None):
     def convert(obj):
         if hasattr(obj, "__dataclass_fields__"):
-            return {k: convert(v) for k, v in asdict(obj).items()}
+            return {k: convert(v) for k, v in asdict(obj).items() if k != ignore_key}
         return obj
 
     return convert(obj)
@@ -91,16 +91,17 @@ class ChatMessage:
     role: str
     content: Optional[str] = None
     tool_calls: Optional[List[ChatMessageToolCall]] = None
+    raw: Optional[Any] = None  # Stores the raw output from the API
 
     def model_dump_json(self):
-        return json.dumps(get_dict_from_nested_dataclasses(self))
+        return json.dumps(get_dict_from_nested_dataclasses(self, ignore_key="raw"))
 
     @classmethod
-    def from_hf_api(cls, message) -> "ChatMessage":
+    def from_hf_api(cls, message, raw) -> "ChatMessage":
         tool_calls = None
         if getattr(message, "tool_calls", None) is not None:
             tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
-        return cls(role=message.role, content=message.content, tool_calls=tool_calls)
+        return cls(role=message.role, content=message.content, tool_calls=tool_calls, raw=raw)
 
     @classmethod
     def from_dict(cls, data: dict) -> "ChatMessage":
@@ -114,6 +115,9 @@ def from_dict(cls, data: dict) -> "ChatMessage":
             data["tool_calls"] = tool_calls
         return cls(**data)
 
+    def dict(self):
+        return json.dumps(get_dict_from_nested_dataclasses(self))
+
 
 def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
     if isinstance(arguments, dict):
@@ -395,7 +399,7 @@ def __call__(
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        message = ChatMessage.from_hf_api(response.choices[0].message)
+        message = ChatMessage.from_hf_api(response.choices[0].message, raw=response)
         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
         return message
@@ -587,7 +591,11 @@ def __call__(
             output = remove_stop_sequences(output, stop_sequences)
 
         if tools_to_call_from is None:
-            return ChatMessage(role="assistant", content=output)
+            return ChatMessage(
+                role="assistant",
+                content=output,
+                raw={"out": out, "completion_kwargs": completion_kwargs},
+            )
         else:
             if "Action:" in output:
                 output = output.split("Action:", 1)[1].strip()
@@ -614,6 +622,7 @@ def __call__(
                         function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
                     )
                 ],
+                raw={"out": out, "completion_kwargs": completion_kwargs},
             )
 
 
@@ -678,10 +687,10 @@ def __call__(
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-
         message = ChatMessage.from_dict(
             response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
         )
+        message.raw = response
 
         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
@@ -762,6 +771,7 @@ def __call__(
         message = ChatMessage.from_dict(
             response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
         )
+        message.raw = response
         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
         return message
diff --git a/src/smolagents/monitoring.py b/src/smolagents/monitoring.py
index 59f43f443..d56eef4dc 100644
--- a/src/smolagents/monitoring.py
+++ b/src/smolagents/monitoring.py
@@ -41,7 +41,7 @@ def update_metrics(self, step_log):
         """Update the metrics of the monitor.
 
         Args:
-            step_log ([`AgentStepLog`]): Step log to update the monitor with.
+            step_log ([`MemoryStep`]): Step log to update the monitor with.
""" step_duration = step_log.duration self.step_durations.append(step_duration) diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py index 8aa631f1a..f50c37f0e 100644 --- a/src/smolagents/utils.py +++ b/src/smolagents/utils.py @@ -23,12 +23,13 @@ import re import textwrap import types -from enum import IntEnum from functools import lru_cache from io import BytesIO -from typing import Dict, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Tuple, Union -from rich.console import Console + +if TYPE_CHECKING: + from smolagents.memory import AgentLogger __all__ = ["AgentError"] @@ -48,8 +49,6 @@ def _is_pillow_available(): return importlib.util.find_spec("PIL") is not None -console = Console() - BASE_BUILTIN_MODULES = [ "collections", "datetime", @@ -65,29 +64,16 @@ def _is_pillow_available(): ] -class LogLevel(IntEnum): - ERROR = 0 # Only errors - INFO = 1 # Normal output (default) - DEBUG = 2 # Detailed output - - -class AgentLogger: - def __init__(self, level: LogLevel = LogLevel.INFO): - self.level = level - self.console = Console() - - def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs): - if level <= self.level: - self.console.print(*args, **kwargs) - - class AgentError(Exception): """Base class for other agent-related exceptions""" - def __init__(self, message, logger: AgentLogger): + def __init__(self, message, logger: "AgentLogger"): super().__init__(message) self.message = message - logger.log(f"[bold red]{message}[/bold red]", level=LogLevel.ERROR) + logger.log(f"[bold red]{message}[/bold red]", level="ERROR") + + def dict(self) -> Dict[str, str]: + return {"type": self.__class__.__name__, "message": str(self.message)} class AgentParsingError(AgentError): @@ -114,6 +100,32 @@ class AgentGenerationError(AgentError): pass +def make_json_serializable(obj: Any) -> Any: + """Recursive function to make objects JSON serializable""" + if obj is None: + return None + elif isinstance(obj, (str, int, float, bool)): + # Try to parse string as JSON if it looks like a JSON object/array + if isinstance(obj, str): + try: + if (obj.startswith("{") and obj.endswith("}")) or (obj.startswith("[") and obj.endswith("]")): + parsed = json.loads(obj) + return make_json_serializable(parsed) + except json.JSONDecodeError: + pass + return obj + elif isinstance(obj, (list, tuple)): + return [make_json_serializable(item) for item in obj] + elif isinstance(obj, dict): + return {str(k): make_json_serializable(v) for k, v in obj.items()} + elif hasattr(obj, "__dict__"): + # For custom objects, convert their __dict__ to a serializable format + return {"_type": obj.__class__.__name__, **{k: make_json_serializable(v) for k, v in obj.__dict__.items()}} + else: + # For any other type, convert to string + return str(obj) + + def parse_json_blob(json_blob: str) -> Dict[str, str]: try: first_accolade_index = json_blob.find("{") diff --git a/tests/test_agents.py b/tests/test_agents.py index d4cbda22c..8083d2a09 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -303,9 +303,9 @@ def test_fake_toolcalling_agent(self): output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, str) assert "7.2904" in output - assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" - assert "7.2904" in agent.logs[2].observations - assert agent.logs[3].llm_output is None + assert agent.memory.steps[0].task == "What is 2 multiplied by 3.6452?" 
+ assert "7.2904" in agent.memory.steps[1].observations + assert agent.memory.steps[2].model_output is None def test_toolcalling_agent_handles_image_tool_outputs(self): from PIL import Image @@ -348,9 +348,9 @@ def test_fake_code_agent(self): output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, float) assert output == 7.2904 - assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" - assert agent.logs[3].tool_calls == [ - ToolCall(name="python_interpreter", arguments="final_answer(7.2904)", id="call_3") + assert agent.memory.steps[0].task == "What is 2 multiplied by 3.6452?" + assert agent.memory.steps[2].tool_calls == [ + ToolCall(name="python_interpreter", arguments="final_answer(7.2904)", id="call_2") ] def test_additional_args_added_to_task(self): @@ -366,30 +366,30 @@ def test_reset_conversations(self): agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model) output = agent.run("What is 2 multiplied by 3.6452?", reset=True) assert output == 7.2904 - assert len(agent.logs) == 4 + assert len(agent.memory.steps) == 3 output = agent.run("What is 2 multiplied by 3.6452?", reset=False) assert output == 7.2904 - assert len(agent.logs) == 6 + assert len(agent.memory.steps) == 5 output = agent.run("What is 2 multiplied by 3.6452?", reset=True) assert output == 7.2904 - assert len(agent.logs) == 4 + assert len(agent.memory.steps) == 3 def test_code_agent_code_errors_show_offending_line_and_error(self): agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_error) output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, AgentText) assert output == "got an error" - assert "Code execution failed at line 'error_function()'" in str(agent.logs[2].error) - assert "ValueError" in str(agent.logs) + assert "Code execution failed at line 'error_function()'" in str(agent.memory.steps[1].error) + assert "ValueError" in str(agent.memory.steps) def test_code_agent_syntax_error_show_offending_lines(self): agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error) output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, AgentText) assert output == "got an error" - assert ' print("Failing due to unexpected indent")' in str(agent.logs) + assert ' print("Failing due to unexpected indent")' in str(agent.memory.steps) def test_setup_agent_with_empty_toolbox(self): ToolCallingAgent(model=FakeToolCallModel(), tools=[]) @@ -401,8 +401,8 @@ def test_fails_max_steps(self): max_steps=5, ) answer = agent.run("What is 2 multiplied by 3.6452?") - assert len(agent.logs) == 8 - assert type(agent.logs[-1].error) is AgentMaxStepsError + assert len(agent.memory.steps) == 7 + assert type(agent.memory.steps[-1].error) is AgentMaxStepsError assert isinstance(answer, str) def test_tool_descriptions_get_baked_in_system_prompt(self): @@ -638,4 +638,9 @@ def get_weather(location: str, celsius: bool = False) -> str: ) agent = ToolCallingAgent(model=model, tools=[get_weather], max_steps=1) agent.run("What's the weather in Paris?") - assert agent.logs[2].tool_calls[0].name == "get_weather" + assert agent.memory.steps[0].task == "What's the weather in Paris?" 
+ assert agent.memory.steps[1].tool_calls[0].name == "get_weather" + step_memory_dict = agent.memory.get_succinct_steps()[1] + assert step_memory_dict["model_output_message"].tool_calls[0].function.name == "get_weather" + assert step_memory_dict["model_output_message"].raw["completion_kwargs"]["max_new_tokens"] == 100 + assert "model_input_messages" in agent.memory.get_full_steps()[1] diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py index abe174015..1cf8d522e 100644 --- a/tests/test_monitoring.py +++ b/tests/test_monitoring.py @@ -22,12 +22,12 @@ ToolCallingAgent, stream_to_gradio, ) +from smolagents.logger import AgentLogger, LogLevel from smolagents.models import ( ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, ) -from smolagents.utils import AgentLogger, LogLevel class FakeLLMModel: