Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable custom final answer in CodeAgent #769

Merged
merged 8 commits into from
Feb 24, 2025
33 changes: 28 additions & 5 deletions docs/source/en/tutorials/secure_code_execution.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,24 @@ When working with AI agents that execute code, security is paramount. This guide
pip install 'smolagents[e2b]'
```

#### Running your agent in E2B
#### Running your agent in E2B: mono agents

Here's a complete example of running an agent in an E2B sandbox:
We provide a simple way to use an E2B Sandbox: pass `use_e2b_executor=True` when initializing the agent, as follows:
```py
from smolagents import HfApiModel, CodeAgent

agent = CodeAgent(model=HfApiModel(), tools=[], use_e2b_executor=True)

agent.run("Can you give me the 100th Fibonacci number?")
```

However, this does not work (yet) with more complicated multi-agent setups.

#### Running your agent in E2B: multi-agents

To use a multi-agent setup in an E2B sandbox, you need to run your agents entirely from within E2B.

Here is how to do it:

```python
from e2b_code_interpreter import Sandbox
Expand Down Expand Up @@ -119,14 +134,22 @@ agent_code = """
import os
from smolagents import CodeAgent, HfApiModel

# Initialize the agent
# Initialize the agents
agent = CodeAgent(
model=HfApiModel(token=os.getenv("HF_TOKEN"), provider="together"),
tools=[]
tools=[],
name="coder_agent",
description="This agent takes care of your difficult algorithmic problems using code."
)

manager_agent = CodeAgent(
model=HfApiModel(token=os.getenv("HF_TOKEN"), provider="together"),
tools=[],
managed_agents=[agent],
)

# Run the agent
response = agent.run("What's the 20th Fibonacci number?")
response = manager_agent.run("What's the 20th Fibonacci number?")
print(response)
"""

Expand Down
12 changes: 4 additions & 8 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,11 @@ def run(
level=LogLevel.INFO,
title=self.name if hasattr(self, "name") else None,
)

self.memory.steps.append(TaskStep(task=self.task, task_images=images))

if getattr(self, "python_executor", None):
self.python_executor.update_tools({**self.tools, **self.managed_agents})
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albertvillanova this is where the tool update happens.


if stream:
# The steps are returned as they are executed through a generator to iterate on.
return self._run(task=self.task, max_steps=max_steps, images=images)
Expand Down Expand Up @@ -1175,17 +1177,14 @@ def __init__(
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
)

all_tools = {**self.tools, **self.managed_agents}
if use_e2b_executor:
self.python_executor = E2BExecutor(
self.additional_authorized_imports,
list(all_tools.values()),
self.logger,
)
else:
self.python_executor = LocalPythonInterpreter(
self.additional_authorized_imports,
all_tools,
max_print_outputs_length=max_print_outputs_length,
)

Expand Down Expand Up @@ -1253,10 +1252,7 @@ def step(self, memory_step: ActionStep) -> Union[None, Any]:
self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO)
is_final_answer = False
try:
output, execution_logs, is_final_answer = self.python_executor(
code_action,
self.state,
)
output, execution_logs, is_final_answer = self.python_executor(code_action, self.state)
execution_outputs_console = []
if len(execution_logs) > 0:
execution_outputs_console += [
Expand Down
46 changes: 24 additions & 22 deletions src/smolagents/e2b_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import re
import textwrap
from io import BytesIO
from typing import Any, List, Tuple
from typing import Any, Dict, List, Tuple

from PIL import Image

Expand All @@ -37,7 +37,7 @@


class E2BExecutor:
def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
def __init__(self, additional_imports: List[str], logger):
self.logger = logger
try:
from e2b_code_interpreter import Sandbox
Expand Down Expand Up @@ -67,8 +67,25 @@ def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
else:
logger.log(f"Installation of {additional_imports} succeeded!", 0)

def run_code_raise_errors(self, code: str):
    """Run `code` through the sandbox and raise if the execution reports an error.

    Args:
        code: Source code passed to `self.sbx.run_code`.

    Returns:
        The execution object returned by `self.sbx.run_code` when it carries no error.

    Raises:
        ValueError: If the execution carries an error. The message contains the
            captured stdout logs, a header line, the error name and value, and
            the traceback, each separated by a newline.
    """
    # Record that the submitted code matches the final-answer pattern.
    if self.final_answer_pattern.search(code) is not None:
        self.final_answer = True

    execution = self.sbx.run_code(code)
    error = execution.error
    if not error:
        return execution

    sections = [
        "\n".join(str(log) for log in execution.logs.stdout),
        "Executing code yielded an error:",
        f"{error.name}: {error.value}",
        error.traceback,
    ]
    # Join with newlines so logs, header and error details do not run together;
    # empty sections (e.g. no stdout, missing traceback) are left out.
    raise ValueError("\n".join(str(section) for section in sections if section))

def update_tools(self, tools: Dict[str, Tool]):
tool_codes = []
for tool in tools:
for tool in tools.values():
validate_tool_attributes(tool.__class__, check_imports=False)
tool_code = instance_to_source(tool, base_cls=Tool)
tool_code = tool_code.replace("from smolagents.tools import Tool", "")
Expand All @@ -88,26 +105,10 @@ def forward(self, *args, **kwargs):
)
tool_definition_code += "\n\n".join(tool_codes)

tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
self.logger.log(tool_definition_execution.logs)

def run_code_raise_errors(self, code: str):
    """Run `code` through the sandbox and raise if the execution reports an error.

    Args:
        code: Source code passed to `self.sbx.run_code`.

    Returns:
        The execution object returned by `self.sbx.run_code` when it carries no error.

    Raises:
        ValueError: If the execution carries an error. The message contains the
            captured stdout logs, a header line, the error name and value, and
            the traceback, each separated by a newline.
    """
    # Record that the submitted code matches the final-answer pattern.
    if self.final_answer_pattern.search(code) is not None:
        self.final_answer = True

    execution = self.sbx.run_code(code)
    error = execution.error
    if not error:
        return execution

    sections = [
        "\n".join(str(log) for log in execution.logs.stdout),
        "Executing code yielded an error:",
        f"{error.name}: {error.value}",
        error.traceback,
    ]
    # Join with newlines so logs, header and error details do not run together;
    # empty sections (e.g. no stdout, missing traceback) are left out.
    raise ValueError("\n".join(str(section) for section in sections if section))
execution = self.run_code_raise_errors(tool_definition_code)
self.logger.log(execution.logs)

def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]:
def __call__(self, code_action: str, additional_args: Dict[str, Any]) -> Tuple[Any, Any]:
if len(additional_args) > 0:
# Pickle additional_args to server
import tempfile
Expand All @@ -129,6 +130,7 @@ def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]:
self.logger.log(execution_logs, 1)

execution = self.run_code_raise_errors(code_action)
self.logger.log(execution.logs)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
if not execution.results:
return None, execution_logs, self.final_answer
Expand Down
22 changes: 12 additions & 10 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import numpy as np
import pandas as pd

from .tools import Tool
from .utils import BASE_BUILTIN_MODULES, truncate_content


Expand Down Expand Up @@ -1384,10 +1385,13 @@ def evaluate_python_code(
result = None
state["_print_outputs"] = PrintContainer()

def final_answer(value):
raise FinalAnswerException(value)
if "final_answer" in static_tools:
previous_final_answer = static_tools["final_answer"]

static_tools["final_answer"] = final_answer
def final_answer(value):
raise FinalAnswerException(previous_final_answer(value))

static_tools["final_answer"] = final_answer

try:
for node in expression.body:
Expand Down Expand Up @@ -1416,7 +1420,6 @@ class LocalPythonInterpreter:
def __init__(
self,
additional_authorized_imports: List[str],
tools: Dict,
max_print_outputs_length: Optional[int] = None,
):
self.custom_tools = {}
Expand All @@ -1426,14 +1429,10 @@ def __init__(
self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
self.additional_authorized_imports = additional_authorized_imports
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
# Add base trusted tools to list
self.static_tools = {
**tools,
**BASE_PYTHON_TOOLS.copy(),
}
# TODO: assert self.authorized imports are all installed locally
self.static_tools = None

def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str, bool]:
def __call__(self, code_action: str, additional_variables: Dict[str, Any]) -> Tuple[Any, str, bool]:
self.state.update(additional_variables)
output, is_final_answer = evaluate_python_code(
code_action,
Expand All @@ -1446,5 +1445,8 @@ def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, s
logs = str(self.state["_print_outputs"])
return output, logs, is_final_answer

def update_tools(self, tools: Dict[str, Tool]):
    """Rebuild `self.static_tools` from `tools` plus the base Python tools.

    Args:
        tools: Mapping of tool name to tool to expose to the interpreter.
    """
    merged = {**tools}
    # Base tools are applied last, so they win on a name clash with `tools`.
    merged.update(BASE_PYTHON_TOOLS)
    self.static_tools = merged


__all__ = ["evaluate_python_code", "LocalPythonInterpreter"]
2 changes: 1 addition & 1 deletion src/smolagents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def validate_arguments(self):

if not set(signature.parameters.keys()) == set(self.inputs.keys()):
raise Exception(
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
f"In tool '{self.name}', 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albertvillanova this makes debugging easier

)

json_schema = _convert_type_hints_to_json_schema(self.forward, error_on_missing_type_hints=False)[
Expand Down
37 changes: 36 additions & 1 deletion tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def test_code_agent_code_errors_show_offending_line_and_error(self):
assert "ValueError" in str(agent.memory.steps)

def test_code_agent_code_error_saves_previous_print_outputs(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_error)
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_error, verbosity_level=10)
agent.run("What is 2 multiplied by 3.6452?")
assert "Flag!" in str(agent.memory.steps[1].observations)

Expand Down Expand Up @@ -800,6 +800,41 @@ def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
agent.run("Test request")
assert "secret\\\\" in repr(capture.get())

def test_change_tools_after_init(self):
    """Tools replaced on an already-built agent are the ones used by the next `run`."""
    from smolagents import Tool, tool

    @tool
    def fake_tool_1() -> str:
        """Fake tool"""
        return "1"

    @tool
    def fake_tool_2() -> str:
        """Fake tool"""
        return "2"

    # The fake model always emits the same code: call `fake_tool_1` and hand
    # its result to `final_answer`.
    def fake_code_model(messages, stop_sequences=None, grammar=None) -> ChatMessage:
        return ChatMessage(role="assistant", content="Code:\n```py\nfinal_answer(fake_tool_1())\n```")

    agent = CodeAgent(tools=[fake_tool_1], model=fake_code_model)

    class ModifiedFinalAnswerTool(Tool):
        name = "final_answer"
        description = "Provides a final answer to the given problem."
        inputs = {"answer": {"type": "any", "description": "The final function that solves the problem"}}
        output_type = "string"

        def forward(self, answer) -> str:
            return answer + "FLAG"

    # Swap both tools after construction: the name `fake_tool_1` now points at
    # `fake_tool_2`, and `final_answer` appends "FLAG" to whatever it receives.
    agent.tools["final_answer"] = ModifiedFinalAnswerTool()
    agent.tools["fake_tool_1"] = fake_tool_2

    answer = agent.run("Fake task.")
    # "2" comes from the swapped-in tool, "FLAG" from the custom final answer.
    assert answer == "2FLAG"


class MultiAgentsTests(unittest.TestCase):
def test_multiagents_save(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_e2b_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_e2b_executor_instantiation(self):
with patch("e2b_code_interpreter.Sandbox") as mock_sandbox:
mock_sandbox.return_value.commands.run.return_value.error = None
mock_sandbox.return_value.run_code.return_value.error = None
executor = E2BExecutor(additional_imports=[], tools=[], logger=logger)
executor = E2BExecutor(additional_imports=[], logger=logger)
assert isinstance(executor, E2BExecutor)
assert executor.logger == logger
assert executor.final_answer is False
Expand Down