Skip to content

Commit

Permalink
Add tests and refactor CLI (#892)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertvillanova authored Mar 7, 2025
1 parent ce25b0d commit e6bd39a
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 19 deletions.
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,12 @@ transformers = [
"transformers>=4.0.0,<4.49.0",
"smolagents[torch]",
]
vision = [
"helium",
"selenium",
]
all = [
"smolagents[audio,docker,e2b,gradio,litellm,mcp,openai,telemetry,transformers]",
"smolagents[audio,docker,e2b,gradio,litellm,mcp,openai,telemetry,transformers,vision]",
]
quality = [
"ruff>=0.9.0",
Expand Down
38 changes: 27 additions & 11 deletions src/smolagents/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
leopard_prompt = "How many seconds would it take for a leopard at full speed to run through Pont des Arts?"


def parse_arguments(description):
parser = argparse.ArgumentParser(description=description)
def parse_arguments():
parser = argparse.ArgumentParser(description="Run a CodeAgent with all specified parameters")
parser.add_argument(
"prompt",
type=str,
Expand Down Expand Up @@ -103,15 +103,21 @@ def load_model(model_type: str, model_id: str, api_base: str | None = None, api_
raise ValueError(f"Unsupported model type: {model_type}")


def main():
def main(
prompt: str,
tools: list[str],
model_type: str,
model_id: str,
api_base: str | None = None,
api_key: str | None = None,
imports: list[str] | None = None,
) -> None:
load_dotenv()

args = parse_arguments(description="Run a CodeAgent with all specified parameters")

model = load_model(args.model_type, args.model_id, args.api_base, args.api_key)
model = load_model(model_type, model_id, api_base=api_base, api_key=api_key)

available_tools = []
for tool_name in args.tools:
for tool_name in tools:
if "/" in tool_name:
available_tools.append(Tool.from_space(tool_name))
else:
Expand All @@ -120,11 +126,21 @@ def main():
else:
raise ValueError(f"Tool {tool_name} is not recognized either as a default tool or a Space.")

print(f"Running agent with these tools: {args.tools}")
agent = CodeAgent(tools=available_tools, model=model, additional_authorized_imports=args.imports)
print(f"Running agent with these tools: {tools}")
agent = CodeAgent(tools=available_tools, model=model, additional_authorized_imports=imports)

agent.run(args.prompt)
agent.run(prompt)


if __name__ == "__main__":
main()
args = parse_arguments()

main(
args.prompt,
args.tools,
args.model_type,
args.model_id,
api_base=args.api_base,
api_key=args.api_key,
imports=args.imports,
)
14 changes: 7 additions & 7 deletions src/smolagents/vision_web_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,24 +187,24 @@ def initialize_agent(model):
"""


def main():
def main(prompt: str, model_type: str, model_id: str) -> None:
# Load environment variables
load_dotenv()

# Parse command line arguments
args = parse_arguments()

# Initialize the model based on the provided arguments
model = load_model(args.model_type, args.model_id)
model = load_model(model_type, model_id)

global driver
driver = initialize_driver()
agent = initialize_agent(model)

# Run the agent with the provided prompt
agent.python_executor("from helium import *")
agent.run(args.prompt + helium_instructions)
agent.run(prompt + helium_instructions)


if __name__ == "__main__":
main()
# Parse command line arguments
args = parse_arguments()

main(args.prompt, args.model_type, args.model_id)
54 changes: 54 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from smolagents.cli import load_model
from smolagents.local_python_executor import LocalPythonExecutor
from smolagents.models import HfApiModel, LiteLLMModel, OpenAIServerModel, TransformersModel


Expand Down Expand Up @@ -52,3 +53,56 @@ def test_load_model_hf_api_model(set_env_vars):
def test_load_model_invalid_model_type():
with pytest.raises(ValueError, match="Unsupported model type: InvalidModel"):
load_model("InvalidModel", "test_model_id")


def test_cli_main(capsys):
with patch("smolagents.cli.load_model") as mock_load_model:
mock_load_model.return_value = "mock_model"
with patch("smolagents.cli.CodeAgent") as mock_code_agent:
from smolagents.cli import main

main("test_prompt", [], "HfApiModel", "test_model_id")
# load_model
assert len(mock_load_model.call_args_list) == 1
assert mock_load_model.call_args.args == ("HfApiModel", "test_model_id")
assert mock_load_model.call_args.kwargs == {"api_base": None, "api_key": None}
# CodeAgent
assert len(mock_code_agent.call_args_list) == 1
assert mock_code_agent.call_args.args == ()
assert mock_code_agent.call_args.kwargs == {
"tools": [],
"model": "mock_model",
"additional_authorized_imports": None,
}
# agent.run
assert len(mock_code_agent.return_value.run.call_args_list) == 1
assert mock_code_agent.return_value.run.call_args.args == ("test_prompt",)
# print
captured = capsys.readouterr()
assert "Running agent with these tools: []" in captured.out


def test_vision_web_browser_main():
with patch("smolagents.vision_web_browser.helium"):
with patch("smolagents.vision_web_browser.load_model") as mock_load_model:
mock_load_model.return_value = "mock_model"
with patch("smolagents.vision_web_browser.CodeAgent") as mock_code_agent:
from smolagents.vision_web_browser import helium_instructions, main

main("test_prompt", "HfApiModel", "test_model_id")
# load_model
assert len(mock_load_model.call_args_list) == 1
assert mock_load_model.call_args.args == ("HfApiModel", "test_model_id")
# CodeAgent
assert len(mock_code_agent.call_args_list) == 1
assert mock_code_agent.call_args.args == ()
assert len(mock_code_agent.call_args.kwargs["tools"]) == 4
assert mock_code_agent.call_args.kwargs["model"] == "mock_model"
assert mock_code_agent.call_args.kwargs["additional_authorized_imports"] == ["helium"]
# agent.python_executor
assert len(mock_code_agent.return_value.python_executor.call_args_list) == 1
assert mock_code_agent.return_value.python_executor.call_args.args == ("from helium import *",)
assert LocalPythonExecutor(["helium"])("from helium import *") == (None, "", False)
# agent.run
assert len(mock_code_agent.return_value.run.call_args_list) == 1
assert mock_code_agent.return_value.run.call_args.args == ("test_prompt" + helium_instructions,)

0 comments on commit e6bd39a

Please sign in to comment.