Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix regressions, refactor logging and improve replay function #419

Merged
merged 5 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/smolagents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from .e2b_executor import *
from .gradio_ui import *
from .local_python_executor import *
from .logger import *
from .memory import *
from .models import *
from .monitoring import *
Expand Down
81 changes: 26 additions & 55 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@
from logging import getLogger
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

from rich import box
from rich.console import Group
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.text import Text

from smolagents.logger import (
from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
from smolagents.monitoring import (
YELLOW_HEX,
AgentLogger,
LogLevel,
)
from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
from smolagents.types import AgentAudio, AgentImage, handle_agent_output_types
from smolagents.utils import (
AgentError,
Expand Down Expand Up @@ -123,9 +122,6 @@ def format_prompt_with_managed_agents_descriptions(
return prompt_template.replace(agent_descriptions_placeholder, "")


YELLOW_HEX = "#d4b702"


class MultiStepAgent:
"""
Agent class that solves the given task step by step, using the ReAct framework:
Expand Down Expand Up @@ -399,14 +395,9 @@ def run(
self.memory.reset()
self.monitor.reset()

self.logger.log(
Panel(
f"\n[bold]{self.task.strip()}\n",
title="[bold]New run",
subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}",
border_style=YELLOW_HEX,
subtitle_align="left",
),
self.logger.log_task(
content=self.task.strip(),
subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}",
level=LogLevel.INFO,
)

Expand Down Expand Up @@ -451,14 +442,7 @@ def _run(self, task: str, images: List[str] | None = None) -> Generator[ActionSt
is_first_step=(self.step_number == 0),
step=self.step_number,
)
self.logger.log(
Rule(
f"[bold]Step {self.step_number}",
characters="━",
style=YELLOW_HEX,
),
level=LogLevel.INFO,
)
self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)

# Run one step!
final_answer = self.step(memory_step)
Expand Down Expand Up @@ -520,8 +504,9 @@ def planning_step(self, task, is_first_step: bool, step: int) -> None:
```
Now begin!""",
}
input_messages = [message_prompt_facts, message_prompt_task]

chat_message_facts: ChatMessage = self.model([message_prompt_facts, message_prompt_task])
chat_message_facts: ChatMessage = self.model(input_messages)
answer_facts = chat_message_facts.content

message_system_prompt_plan = {
Expand Down Expand Up @@ -553,6 +538,7 @@ def planning_step(self, task, is_first_step: bool, step: int) -> None:
```""".strip()
self.memory.steps.append(
PlanningStep(
model_input_messages=input_messages,
plan=final_plan_redaction,
facts=final_facts_redaction,
model_output_message_plan=chat_message_plan,
Expand All @@ -578,9 +564,8 @@ def planning_step(self, task, is_first_step: bool, step: int) -> None:
"role": MessageRole.USER,
"content": [{"type": "text", "text": USER_PROMPT_FACTS_UPDATE}],
}
chat_message_facts: ChatMessage = self.model(
[facts_update_system_prompt] + memory_messages + [facts_update_message]
)
input_messages = [facts_update_system_prompt] + memory_messages + [facts_update_message]
chat_message_facts: ChatMessage = self.model(input_messages)
facts_update = chat_message_facts.content

# Redact updated plan
Expand Down Expand Up @@ -618,6 +603,7 @@ def planning_step(self, task, is_first_step: bool, step: int) -> None:
```"""
self.memory.steps.append(
PlanningStep(
model_input_messages=input_messages,
plan=final_plan_redaction,
facts=final_facts_redaction,
model_output_message_plan=chat_message_plan,
Expand All @@ -630,6 +616,15 @@ def planning_step(self, task, is_first_step: bool, step: int) -> None:
level=LogLevel.INFO,
)

def replay(self, detailed: bool = False):
"""Prints a pretty replay of the agent's steps.

Args:
detailed (bool, optional): If True, also displays the memory at each step. Defaults to False.
Careful: will increase log length exponentially. Use only for debugging.
"""
self.memory.replay(self.logger, detailed=detailed)


class ToolCallingAgent(MultiStepAgent):
"""
Expand Down Expand Up @@ -851,20 +846,9 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
except Exception as e:
raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) from e

self.logger.log(
Group(
Rule(
"[italic]Output message of the LLM:",
align="left",
style="orange",
),
Syntax(
model_output,
lexer="markdown",
theme="github-dark",
word_wrap=True,
),
),
self.logger.log_markdown(
content=model_output,
title="Output message of the LLM:",
level=LogLevel.DEBUG,
)

Expand All @@ -884,20 +868,7 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
]

# Execute
self.logger.log(
Panel(
Syntax(
code_action,
lexer="python",
theme="monokai",
word_wrap=True,
),
title="[bold]Executing this code:",
title_align="left",
box=box.HORIZONTALS,
),
level=LogLevel.INFO,
)
self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO)
observation = ""
is_final_answer = False
try:
Expand Down
85 changes: 0 additions & 85 deletions src/smolagents/logger.py

This file was deleted.

52 changes: 44 additions & 8 deletions src/smolagents/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, Union

from smolagents.models import ChatMessage, MessageRole
from smolagents.monitoring import AgentLogger
from smolagents.utils import AgentError, make_json_serializable


if TYPE_CHECKING:
from smolagents.logger import AgentLogger
from smolagents.models import ChatMessage


Expand Down Expand Up @@ -47,7 +49,7 @@ def to_messages(self, **kwargs) -> List[Dict[str, Any]]:

@dataclass
class ActionStep(MemoryStep):
model_input_messages: List[Dict[str, str]] | None = None
model_input_messages: List[Message] | None = None
tool_calls: List[ToolCall] | None = None
start_time: float | None = None
end_time: float | None = None
Expand Down Expand Up @@ -76,7 +78,7 @@ def dict(self):
"action_output": make_json_serializable(self.action_output),
}

def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Dict[str, Any]]:
def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Message]:
messages = []
if self.model_input_messages is not None and show_model_input_messages:
messages.append(Message(role=MessageRole.SYSTEM, content=self.model_input_messages))
Expand Down Expand Up @@ -142,17 +144,26 @@ def to_messages(self, summary_mode: bool = False, show_model_input_messages: boo

@dataclass
class PlanningStep(MemoryStep):
model_input_messages: List[Message]
model_output_message_facts: ChatMessage
facts: str
model_output_message_facts: ChatMessage
model_output_message_plan: ChatMessage
plan: str

def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
def to_messages(self, summary_mode: bool, **kwargs) -> List[Message]:
messages = []
messages.append(Message(role=MessageRole.ASSISTANT, content=f"[FACTS LIST]:\n{self.facts.strip()}"))
messages.append(
Message(
role=MessageRole.ASSISTANT, content=[{"type": "text", "text": f"[FACTS LIST]:\n{self.facts.strip()}"}]
)
)

if not summary_mode:
messages.append(Message(role=MessageRole.ASSISTANT, content=f"[PLAN]:\n{self.plan.strip()}"))
messages.append(
Message(
role=MessageRole.ASSISTANT, content=[{"type": "text", "text": f"[PLAN]:\n{self.plan.strip()}"}]
)
)
return messages


Expand All @@ -161,7 +172,7 @@ class TaskStep(MemoryStep):
task: str
task_images: List[str] | None = None

def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]:
content = [{"type": "text", "text": f"New task:\n{self.task}"}]
if self.task_images:
for image in self.task_images:
Expand All @@ -174,7 +185,7 @@ def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
class SystemPromptStep(MemoryStep):
system_prompt: str

def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]:
if summary_mode:
return []
return [Message(role=MessageRole.SYSTEM, content=[{"type": "text", "text": self.system_prompt.strip()}])]
Expand All @@ -196,5 +207,30 @@ def get_succinct_steps(self) -> list[dict]:
def get_full_steps(self) -> list[dict]:
return [step.dict() for step in self.steps]

def replay(self, logger: AgentLogger, detailed: bool = False):
    """Print a pretty replay of every recorded memory step to `logger`.

    Args:
        logger (AgentLogger): The logger to print replay logs to.
        detailed (bool, optional): If True, also displays the memory
            (model input messages) at each step. Defaults to False.
            Careful: will increase log length exponentially. Use only
            for debugging.
    """
    logger.console.log("Replaying the agent's steps:")
    for memory_step in self.steps:
        if isinstance(memory_step, SystemPromptStep):
            # The system prompt is only interesting when debugging in detail.
            if detailed:
                logger.log_markdown(title="System prompt", content=memory_step.system_prompt)
        elif isinstance(memory_step, TaskStep):
            logger.log_task(memory_step.task, "", 2)
        elif isinstance(memory_step, ActionStep):
            logger.log_rule(f"Step {memory_step.step_number}")
            if detailed:
                logger.log_messages(memory_step.model_input_messages)
            logger.log_markdown(title="Agent output:", content=memory_step.model_output)
        elif isinstance(memory_step, PlanningStep):
            logger.log_rule("Planning step")
            if detailed:
                logger.log_messages(memory_step.model_input_messages)
            logger.log_markdown(title="Agent output:", content="\n".join([memory_step.facts, memory_step.plan]))


__all__ = ["AgentMemory"]
Loading