Skip to content

Commit

Permalink
Enable custom final answer in CodeAgent (#769)
Browse files Browse the repository at this point in the history
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
  • Loading branch information
aymeric-roucher and albertvillanova authored Feb 24, 2025
1 parent 44f4336 commit 99102f1
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 48 deletions.
33 changes: 28 additions & 5 deletions docs/source/en/tutorials/secure_code_execution.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,24 @@ When working with AI agents that execute code, security is paramount. This guide
pip install 'smolagents[e2b]'
```

#### Running your agent in E2B
#### Running your agent in E2B: mono agents

Here's a complete example of running an agent in an E2B sandbox:
We provide a simple way to use an E2B Sandbox: simply add `use_e2b_executor=True` to the agent initialization, like:
```py
from smolagents import HfApiModel, CodeAgent

agent = CodeAgent(model=HfApiModel(), tools=[], use_e2b_executor=True)

agent.run("Can you give me the 100th Fibonacci number?")
```

However, this does not work (yet) with more complicated multi-agent setups.

#### Running your agent in E2B: multi-agents

To use multi-agents in an E2B sandbox, you need to run your agents completely from within E2B.

Here is how to do it:

```python
from e2b_code_interpreter import Sandbox
Expand Down Expand Up @@ -119,14 +134,22 @@ agent_code = """
import os
from smolagents import CodeAgent, HfApiModel
# Initialize the agent
# Initialize the agents
agent = CodeAgent(
model=HfApiModel(token=os.getenv("HF_TOKEN"), provider="together"),
tools=[]
tools=[],
name="coder_agent",
description="This agent takes care of your difficult algorithmic problems using code."
)
manager_agent = CodeAgent(
model=HfApiModel(token=os.getenv("HF_TOKEN"), provider="together"),
tools=[],
managed_agents=[agent],
)
# Run the agent
response = agent.run("What's the 20th Fibonacci number?")
response = manager_agent.run("What's the 20th Fibonacci number?")
print(response)
"""

Expand Down
12 changes: 4 additions & 8 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,11 @@ def run(
level=LogLevel.INFO,
title=self.name if hasattr(self, "name") else None,
)

self.memory.steps.append(TaskStep(task=self.task, task_images=images))

if getattr(self, "python_executor", None):
self.python_executor.update_tools({**self.tools, **self.managed_agents})

if stream:
# The steps are returned as they are executed through a generator to iterate on.
return self._run(task=self.task, max_steps=max_steps, images=images)
Expand Down Expand Up @@ -1175,17 +1177,14 @@ def __init__(
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
)

all_tools = {**self.tools, **self.managed_agents}
if use_e2b_executor:
self.python_executor = E2BExecutor(
self.additional_authorized_imports,
list(all_tools.values()),
self.logger,
)
else:
self.python_executor = LocalPythonInterpreter(
self.additional_authorized_imports,
all_tools,
max_print_outputs_length=max_print_outputs_length,
)

Expand Down Expand Up @@ -1253,10 +1252,7 @@ def step(self, memory_step: ActionStep) -> Union[None, Any]:
self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO)
is_final_answer = False
try:
output, execution_logs, is_final_answer = self.python_executor(
code_action,
self.state,
)
output, execution_logs, is_final_answer = self.python_executor(code_action, self.state)
execution_outputs_console = []
if len(execution_logs) > 0:
execution_outputs_console += [
Expand Down
46 changes: 24 additions & 22 deletions src/smolagents/e2b_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import re
import textwrap
from io import BytesIO
from typing import Any, List, Tuple
from typing import Any, Dict, List, Tuple

from PIL import Image

Expand All @@ -37,7 +37,7 @@


class E2BExecutor:
def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
def __init__(self, additional_imports: List[str], logger):
self.logger = logger
try:
from e2b_code_interpreter import Sandbox
Expand Down Expand Up @@ -67,8 +67,25 @@ def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
else:
logger.log(f"Installation of {additional_imports} succeeded!", 0)

def run_code_raise_errors(self, code: str):
if self.final_answer_pattern.search(code) is not None:
self.final_answer = True
execution = self.sbx.run_code(
code,
)
if execution.error:
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
logs = execution_logs
logs += "Executing code yielded an error:"
logs += execution.error.name
logs += execution.error.value
logs += execution.error.traceback
raise ValueError(logs)
return execution

def update_tools(self, tools: Dict[str, Tool]):
tool_codes = []
for tool in tools:
for tool in tools.values():
validate_tool_attributes(tool.__class__, check_imports=False)
tool_code = instance_to_source(tool, base_cls=Tool)
tool_code = tool_code.replace("from smolagents.tools import Tool", "")
Expand All @@ -88,26 +105,10 @@ def forward(self, *args, **kwargs):
)
tool_definition_code += "\n\n".join(tool_codes)

tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
self.logger.log(tool_definition_execution.logs)

def run_code_raise_errors(self, code: str):
if self.final_answer_pattern.search(code) is not None:
self.final_answer = True
execution = self.sbx.run_code(
code,
)
if execution.error:
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
logs = execution_logs
logs += "Executing code yielded an error:"
logs += execution.error.name
logs += execution.error.value
logs += execution.error.traceback
raise ValueError(logs)
return execution
execution = self.run_code_raise_errors(tool_definition_code)
self.logger.log(execution.logs)

def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]:
def __call__(self, code_action: str, additional_args: Dict[str, Any]) -> Tuple[Any, Any]:
if len(additional_args) > 0:
# Pickle additional_args to server
import tempfile
Expand All @@ -129,6 +130,7 @@ def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]:
self.logger.log(execution_logs, 1)

execution = self.run_code_raise_errors(code_action)
self.logger.log(execution.logs)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
if not execution.results:
return None, execution_logs, self.final_answer
Expand Down
22 changes: 12 additions & 10 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import numpy as np
import pandas as pd

from .tools import Tool
from .utils import BASE_BUILTIN_MODULES, truncate_content


Expand Down Expand Up @@ -1384,10 +1385,13 @@ def evaluate_python_code(
result = None
state["_print_outputs"] = PrintContainer()

def final_answer(value):
raise FinalAnswerException(value)
if "final_answer" in static_tools:
previous_final_answer = static_tools["final_answer"]

static_tools["final_answer"] = final_answer
def final_answer(value):
raise FinalAnswerException(previous_final_answer(value))

static_tools["final_answer"] = final_answer

try:
for node in expression.body:
Expand Down Expand Up @@ -1416,7 +1420,6 @@ class LocalPythonInterpreter:
def __init__(
self,
additional_authorized_imports: List[str],
tools: Dict,
max_print_outputs_length: Optional[int] = None,
):
self.custom_tools = {}
Expand All @@ -1426,14 +1429,10 @@ def __init__(
self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
self.additional_authorized_imports = additional_authorized_imports
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
# Add base trusted tools to list
self.static_tools = {
**tools,
**BASE_PYTHON_TOOLS.copy(),
}
# TODO: assert self.authorized imports are all installed locally
self.static_tools = None

def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str, bool]:
def __call__(self, code_action: str, additional_variables: Dict[str, Any]) -> Tuple[Any, str, bool]:
self.state.update(additional_variables)
output, is_final_answer = evaluate_python_code(
code_action,
Expand All @@ -1446,5 +1445,8 @@ def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, s
logs = str(self.state["_print_outputs"])
return output, logs, is_final_answer

def update_tools(self, tools: Dict[str, Tool]):
self.static_tools = {**tools, **BASE_PYTHON_TOOLS.copy()}


__all__ = ["evaluate_python_code", "LocalPythonInterpreter"]
2 changes: 1 addition & 1 deletion src/smolagents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def validate_arguments(self):

if not set(signature.parameters.keys()) == set(self.inputs.keys()):
raise Exception(
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
f"In tool '{self.name}', 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
)

json_schema = _convert_type_hints_to_json_schema(self.forward, error_on_missing_type_hints=False)[
Expand Down
37 changes: 36 additions & 1 deletion tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def test_code_agent_code_errors_show_offending_line_and_error(self):
assert "ValueError" in str(agent.memory.steps)

def test_code_agent_code_error_saves_previous_print_outputs(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_error)
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_error, verbosity_level=10)
agent.run("What is 2 multiplied by 3.6452?")
assert "Flag!" in str(agent.memory.steps[1].observations)

Expand Down Expand Up @@ -800,6 +800,41 @@ def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
agent.run("Test request")
assert "secret\\\\" in repr(capture.get())

def test_change_tools_after_init(self):
from smolagents import Tool, tool

@tool
def fake_tool_1() -> str:
"""Fake tool"""
return "1"

@tool
def fake_tool_2() -> str:
"""Fake tool"""
return "2"

def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
return ChatMessage(role="assistant", content="Code:\n```py\nfinal_answer(fake_tool_1())\n```")

agent = CodeAgent(tools=[fake_tool_1], model=fake_code_model)

class ModifiedFinalAnswerTool(Tool):
name = "final_answer"
description = "Provides a final answer to the given problem."
inputs = {"answer": {"type": "any", "description": "The final function that solves the problem"}}
output_type = "string"

def forward(self, answer) -> str:
return answer + "FLAG"

agent.tools["final_answer"] = ModifiedFinalAnswerTool()
agent.tools["fake_tool_1"] = fake_tool_2

answer = agent.run("Fake task.")
assert answer == "2FLAG"

agent = CodeAgent(tools=[], model=fake_code_model)


class MultiAgentsTests(unittest.TestCase):
def test_multiagents_save(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_e2b_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_e2b_executor_instantiation(self):
with patch("e2b_code_interpreter.Sandbox") as mock_sandbox:
mock_sandbox.return_value.commands.run.return_value.error = None
mock_sandbox.return_value.run_code.return_value.error = None
executor = E2BExecutor(additional_imports=[], tools=[], logger=logger)
executor = E2BExecutor(additional_imports=[], logger=logger)
assert isinstance(executor, E2BExecutor)
assert executor.logger == logger
assert executor.final_answer is False
Expand Down

0 comments on commit 99102f1

Please sign in to comment.