Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable custom final answer in CodeAgent #769

Merged
merged 8 commits into from
Feb 24, 2025
32 changes: 27 additions & 5 deletions docs/source/en/tutorials/secure_code_execution.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,24 @@ When working with AI agents that execute code, security is paramount. This guide
pip install 'smolagents[e2b]'
```

#### Running your agent in E2B
#### Running your agent in E2B: mono agents

Here's a complete example of running an agent in an E2B sandbox:
We provide a simple way to use an E2B Sandbox: simply add `use_e2b_executor=True` to the agent initialization, like:
```py
from smolagents import HfApiModel, CodeAgent

agent = CodeAgent(model=HfApiModel(), tools=[], use_e2b_executor=True)

agent.run("Can you give me the 100th Fibonacci number?")
````

However, this does not work with more complicated multi-agent setups.

#### Running your agent in E2B: multi-agents

To use multi-agents in an E2B sandbox, you need to run your agents completely from within E2B.

Here is how to do it:

```python
from e2b_code_interpreter import Sandbox
Expand Down Expand Up @@ -119,14 +134,21 @@ agent_code = """
import os
from smolagents import CodeAgent, HfApiModel

# Initialize the agent
# Initialize the agents
agent = CodeAgent(
model=HfApiModel(token=os.getenv("HF_TOKEN"), provider="together"),
tools=[]
tools=[],
name="coder_agent",
description="This agent takes care of your difficult algorithmic problems using code."
)

manager_agent = CodeAgent(
model=HfApiModel(token=os.getenv("HF_TOKEN"), provider="together"),
tools=[],
)

# Run the agent
response = agent.run("What's the 20th Fibonacci number?")
response = manager_agent.run("What's the 20th Fibonacci number?")
print(response)
"""

Expand Down
6 changes: 1 addition & 5 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,17 +1175,14 @@ def __init__(
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
)

all_tools = {**self.tools, **self.managed_agents}
if use_e2b_executor:
self.python_executor = E2BExecutor(
self.additional_authorized_imports,
list(all_tools.values()),
self.logger,
)
else:
self.python_executor = LocalPythonInterpreter(
self.additional_authorized_imports,
all_tools,
max_print_outputs_length=max_print_outputs_length,
)

Expand Down Expand Up @@ -1254,8 +1251,7 @@ def step(self, memory_step: ActionStep) -> Union[None, Any]:
is_final_answer = False
try:
output, execution_logs, is_final_answer = self.python_executor(
code_action,
self.state,
code_action, self.state, {**self.tools, **self.managed_agents}
)
execution_outputs_console = []
if len(execution_logs) > 0:
Expand Down
54 changes: 26 additions & 28 deletions src/smolagents/e2b_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import re
import textwrap
from io import BytesIO
from typing import Any, List, Tuple
from typing import Any, Dict, List, Tuple

from PIL import Image

Expand All @@ -37,7 +37,7 @@


class E2BExecutor:
def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
def __init__(self, additional_imports: List[str], logger):
self.logger = logger
try:
from e2b_code_interpreter import Sandbox
Expand Down Expand Up @@ -67,30 +67,6 @@ def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
else:
logger.log(f"Installation of {additional_imports} succeeded!", 0)

tool_codes = []
for tool in tools:
validate_tool_attributes(tool.__class__, check_imports=False)
tool_code = instance_to_source(tool, base_cls=Tool)
tool_code = tool_code.replace("from smolagents.tools import Tool", "")
tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n"
tool_codes.append(tool_code)

tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
tool_definition_code += textwrap.dedent(
"""
class Tool:
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

def forward(self, *args, **kwargs):
pass # to be implemented in child class
"""
)
tool_definition_code += "\n\n".join(tool_codes)

tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
self.logger.log(tool_definition_execution.logs)

def run_code_raise_errors(self, code: str):
if self.final_answer_pattern.search(code) is not None:
self.final_answer = True
Expand All @@ -107,7 +83,7 @@ def run_code_raise_errors(self, code: str):
raise ValueError(logs)
return execution

def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]:
def __call__(self, code_action: str, additional_args: Dict[str, Any], tools: Dict[str, Tool]) -> Tuple[Any, Any]:
if len(additional_args) > 0:
# Pickle additional_args to server
import tempfile
Expand All @@ -128,7 +104,29 @@ def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]:
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
self.logger.log(execution_logs, 1)

execution = self.run_code_raise_errors(code_action)
tool_codes = []
for tool in tools.values():
validate_tool_attributes(tool.__class__, check_imports=False)
tool_code = instance_to_source(tool, base_cls=Tool)
tool_code = tool_code.replace("from smolagents.tools import Tool", "")
tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n"
tool_codes.append(tool_code)

tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
tool_definition_code += textwrap.dedent(
"""
class Tool:
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

def forward(self, *args, **kwargs):
pass # to be implemented in child class
"""
)
tool_definition_code += "\n\n".join(tool_codes)

execution = self.run_code_raise_errors(tool_definition_code + code_action)
self.logger.log(execution.logs)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
if not execution.results:
return None, execution_logs, self.final_answer
Expand Down
22 changes: 11 additions & 11 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import numpy as np
import pandas as pd

from .tools import Tool
from .utils import BASE_BUILTIN_MODULES, truncate_content


Expand Down Expand Up @@ -1384,10 +1385,13 @@ def evaluate_python_code(
result = None
state["_print_outputs"] = PrintContainer()

def final_answer(value):
raise FinalAnswerException(value)
if "final_answer" in static_tools:
previous_final_answer = static_tools["final_answer"]

static_tools["final_answer"] = final_answer
def final_answer(value):
raise FinalAnswerException(previous_final_answer(value))

static_tools["final_answer"] = final_answer

try:
for node in expression.body:
Expand Down Expand Up @@ -1416,7 +1420,6 @@ class LocalPythonInterpreter:
def __init__(
self,
additional_authorized_imports: List[str],
tools: Dict,
max_print_outputs_length: Optional[int] = None,
):
self.custom_tools = {}
Expand All @@ -1426,18 +1429,15 @@ def __init__(
self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
self.additional_authorized_imports = additional_authorized_imports
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
# Add base trusted tools to list
self.static_tools = {
**tools,
**BASE_PYTHON_TOOLS.copy(),
}
# TODO: assert self.authorized imports are all installed locally

def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str, bool]:
def __call__(
self, code_action: str, additional_variables: Dict[str, Any], tools: Dict[str, Tool]
) -> Tuple[Any, str, bool]:
self.state.update(additional_variables)
output, is_final_answer = evaluate_python_code(
code_action,
static_tools=self.static_tools,
static_tools={**tools, **BASE_PYTHON_TOOLS.copy()},
custom_tools=self.custom_tools,
state=self.state,
authorized_imports=self.authorized_imports,
Expand Down
2 changes: 1 addition & 1 deletion src/smolagents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def validate_arguments(self):

if not set(signature.parameters.keys()) == set(self.inputs.keys()):
raise Exception(
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
f"In tool '{self.name}', 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albertvillanova this makes debugging easier

)

json_schema = _convert_type_hints_to_json_schema(self.forward, error_on_missing_type_hints=False)[
Expand Down
35 changes: 35 additions & 0 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,41 @@ def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
agent.run("Test request")
assert "secret\\\\" in repr(capture.get())

def test_change_tools_after_init(self):
from smolagents import Tool, tool

@tool
def fake_tool_1() -> str:
"""Fake tool"""
return "1"

@tool
def fake_tool_2() -> str:
"""Fake tool"""
return "2"

def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
return ChatMessage(role="assistant", content="Code:\n```py\nfinal_answer(fake_tool_1())\n```")

agent = CodeAgent(tools=[fake_tool_1], model=fake_code_model)

class ModifiedFinalAnswerTool(Tool):
name = "final_answer"
description = "Provides a final answer to the given problem."
inputs = {"answer": {"type": "any", "description": "The final function that solves the problem"}}
output_type = "string"

def forward(self, answer) -> str:
return answer + "FLAG"

agent.tools["final_answer"] = ModifiedFinalAnswerTool()
agent.tools["fake_tool_1"] = fake_tool_2

answer = agent.run("Fake task.")
assert answer == "2FLAG"

agent = CodeAgent(tools=[], model=fake_code_model)


class MultiAgentsTests(unittest.TestCase):
def test_multiagents_save(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_e2b_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_e2b_executor_instantiation(self):
with patch("e2b_code_interpreter.Sandbox") as mock_sandbox:
mock_sandbox.return_value.commands.run.return_value.error = None
mock_sandbox.return_value.run_code.return_value.error = None
executor = E2BExecutor(additional_imports=[], tools=[], logger=logger)
executor = E2BExecutor(additional_imports=[], logger=logger)
assert isinstance(executor, E2BExecutor)
assert executor.logger == logger
assert executor.final_answer is False
Expand Down