diff --git a/docs/source/en/guided_tour.mdx b/docs/source/en/guided_tour.mdx index 5eca7fc21..d535ba11f 100644 --- a/docs/source/en/guided_tour.mdx +++ b/docs/source/en/guided_tour.mdx @@ -349,6 +349,47 @@ Out[20]: 'ByteDance/AnimateDiff-Lightning' > [!TIP] > Read more on tools in the [dedicated tutorial](./tutorials/tools#what-is-a-tool-and-how-to-build-one). + +### Tools from methods +Sometimes, you need to preserve state or perform an initialization step, such as handling tokens, setting up API clients, or sharing resources among tools. In these cases, you can encapsulate a tool within a class method. + +You still use the @tool decorator, but you apply it directly to a method. Let's modify our previous example to handle authentication using a HuggingFace token: +```python +from smolagents import tool + +class DownloadTools: + def __init__(self, hf_token: str): + self.hf_token = hf_token + + @tool + def model_download_tool(self, task: str) -> str: + """ + Returns the most downloaded model of a given task on the Hugging Face Hub. + Uses the authentication token from initialization. + + Args: + task: The task for which to get the most downloaded model. + """ + most_downloaded_model = next(iter( + list_models( + filter=task, + sort="downloads", + direction=-1, + token=self.hf_token # Pass the token from init + ) + )) + return most_downloaded_model.id + +# Usage: +dl_tools = DownloadTools(hf_token="hf_xxxxxxxx") +agent = CodeAgent(tools=[dl_tools.model_download_tool], model=HfApiModel()) +``` + +> [!TIP] +> You can create multiple tools in a single class by decorating more methods with @tool. +> You can also mix class-based and function-based tools in the same agent. Go with whichever approach suits your needs best! + + ## Multi-agents Multi-agent systems have been introduced with Microsoft's framework [Autogen](https://huggingface.co/papers/2308.08155). diff --git a/docs/source/en/tutorials/tools.mdx b/docs/source/en/tutorials/tools.mdx index d9da1e94f..8e7a19d92 100644 --- a/docs/source/en/tutorials/tools.mdx +++ b/docs/source/en/tutorials/tools.mdx @@ -76,7 +76,7 @@ The custom tool subclasses [`Tool`] to inherit useful methods. The child class a And that's all it needs to be used in an agent! -There's another way to build a tool. In the [guided_tour](../guided_tour), we implemented a tool using the `@tool` decorator. The [`tool`] decorator is the recommended way to define simple tools, but sometimes you need more than this: using several methods in a class for more clarity, or using additional class attributes. +There's another way to build a tool. In the [guided_tour](../guided_tour), we implemented a tool using the `@tool` decorator on a function The [`tool`] decorator is the recommended way to define simple tools, but sometimes you need more than this: using several methods in a class for more clarity, or using additional class attributes. In this case, you can build your tool by subclassing [`Tool`] as described above. diff --git a/src/smolagents/_function_type_hints_utils.py b/src/smolagents/_function_type_hints_utils.py index 13c6a2548..a528ce6d0 100644 --- a/src/smolagents/_function_type_hints_utils.py +++ b/src/smolagents/_function_type_hints_utils.py @@ -83,7 +83,11 @@ class DocstringParsingException(Exception): """Exception raised for errors in parsing docstrings to generate JSON schemas""" -def get_json_schema(func: Callable) -> Dict: +class ToolCreationException(Exception): + """Exception raised for errors in creating a tool from a function or method""" + + +def get_json_schema(func: Callable, skip_self_cls: bool = False) -> Dict: """ This function generates a JSON schema for a given function, based on its docstring and type hints. This is mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of @@ -96,6 +100,7 @@ def get_json_schema(func: Callable) -> Dict: Args: func: The function to generate a JSON schema for. + skip_self_cls: Whether to skip the `self` and `cls` arguments in the schema. This is useful for bound methods in classes. Returns: A dictionary containing the JSON schema for the function. @@ -197,8 +202,7 @@ def get_json_schema(func: Callable) -> Dict: ) doc = doc.strip() main_doc, param_descriptions, return_doc = _parse_google_format_docstring(doc) - - json_schema = _convert_type_hints_to_json_schema(func) + json_schema = _convert_type_hints_to_json_schema(func, skip_self_cls=skip_self_cls) if (return_dict := json_schema["properties"].pop("return", None)) is not None: if return_doc is not None: # We allow a missing return docstring since most templates ignore it return_dict["description"] = return_doc @@ -273,16 +277,19 @@ def _parse_google_format_docstring( return description, args_dict, returns -def _convert_type_hints_to_json_schema(func: Callable, error_on_missing_type_hints: bool = True) -> Dict: +def _convert_type_hints_to_json_schema( + func: Callable, error_on_missing_type_hints: bool = True, skip_self_cls: bool = False +) -> Dict: type_hints = get_type_hints(func) signature = inspect.signature(func) - properties = {} for param_name, param_type in type_hints.items(): properties[param_name] = _parse_type_hint(param_type) required = [] for param_name, param in signature.parameters.items(): + if skip_self_cls and param_name in ["self", "cls"]: + continue if param.annotation == inspect.Parameter.empty and error_on_missing_type_hints: raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") if param_name not in properties: diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index ab3568e92..dab05ffd4 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -38,6 +38,7 @@ from huggingface_hub.utils import is_torch_available from ._function_type_hints_utils import ( + ToolCreationException, TypeHintParsingException, _convert_type_hints_to_json_schema, get_imports, @@ -250,9 +251,11 @@ def replacement(match): "SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper", + "MethodToolWrapper", ]: raise ValueError( - "Cannot save objects created with from_space, from_langchain or from_gradio, as this would create errors." + "Cannot save objects created with from_space, from_langchain, from_gradio, " + "or class methods decorated with @tool as this would create errors." ) validate_tool_attributes(self.__class__) @@ -838,9 +841,28 @@ def tool(tool_function: Callable) -> Tool: Converts a function into an instance of a Tool subclass. Args: - tool_function: Your function. Should have type hints for each input and a type hint for the output. + tool_function: Your function or class method. Should have type hints for each input and a type hint for the output. Should also have a docstring description including an 'Args:' part where each argument is described. """ + original_signature = inspect.signature(tool_function) + params = list(original_signature.parameters.values()) + + is_method = ( # inspect.ismethod won't work here because the decorator is applied before the method is bound + params + and params[0].name in ("self", "cls") + and hasattr(tool_function, "__qualname__") + and "." in tool_function.__qualname__ + ) + + if is_method: + return _tool_from_method(tool_function, original_signature) + elif inspect.isfunction(tool_function): + return _tool_from_function(tool_function, original_signature) + else: + raise ToolCreationException("Tool decorator can only be used on functions or methods.") + + +def _tool_from_function(tool_function: Callable, original_signature: inspect.Signature) -> Tool: tool_json_schema = get_json_schema(tool_function)["function"] if "return" not in tool_json_schema: raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!") @@ -868,7 +890,6 @@ def __init__( output_type=tool_json_schema["return"]["type"], function=tool_function, ) - original_signature = inspect.signature(tool_function) new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)] + list( original_signature.parameters.values() ) @@ -877,6 +898,34 @@ def __init__( return simple_tool +def _tool_from_method(tool_function: Callable, original_signature: inspect.Signature) -> Tool: + tool_json_schema = get_json_schema(tool_function, skip_self_cls=True)["function"] + if "return" not in tool_json_schema: + raise TypeHintParsingException("Tool return type not found: make sure your method has a return type hint!") + + class MethodToolWrapper(Tool): + def __init__(self): + self.name = tool_json_schema["name"] + self.description = tool_json_schema["description"] + self.inputs = tool_json_schema["parameters"]["properties"] + self.output_type = tool_json_schema["return"]["type"] + self.is_initialized = True + # Temp forward function; will be replaced with a bound instance in __get__. (this is for sig validation) + self.forward = lambda: 0 + new_parameters = [v for v in original_signature.parameters.values() if v.name not in ["self", "cls"]] + modified_sig = original_signature.replace(parameters=new_parameters) + self.forward.__signature__ = modified_sig + + # When the decorator is first used, it's on an unbound class method. + # So we have to use a descriptor here to bind the instance to the forward method when accessed. + def __get__(self, instance, owner): + func = getattr(tool_function, "__func__", tool_function) + self.forward = lambda *args, **kwargs: func(instance, *args, **kwargs) + return self + + return MethodToolWrapper() + + class PipelineTool(Tool): """ A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will diff --git a/tests/test_all_docs.py b/tests/test_all_docs.py index 7dcbf5838..6df8dc76c 100644 --- a/tests/test_all_docs.py +++ b/tests/test_all_docs.py @@ -94,7 +94,7 @@ def setup_class(cls): load_dotenv() - cls.md_files = list(cls.docs_dir.rglob("*.md")) + cls.md_files = list(cls.docs_dir.rglob("*.mdx")) if not cls.md_files: raise ValueError(f"No markdown files found in {cls.docs_dir}") diff --git a/tests/test_tools.py b/tests/test_tools.py index cb8a8eeaa..90c26e02d 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -133,6 +133,58 @@ def forward(self, task: str) -> str: tool = HFModelDownloadsTool() assert list(tool.inputs.keys())[0] == "task" + def test_method_tool_init_decorator(self): + class ClassWithToolMethod: + d = 3 + + def __init__(self, c: int): + self.c = c + + @tool + def sumfunc_instance_method(self, a: int, b: int) -> int: + """Sum function as instance method + + Args: + a: The first argument + b: The second one + """ + return a + b + self.c + + @classmethod + @tool + def sumfunc_class_method(cls, a: int, b: int) -> int: + """Sum function as class method + + Args: + a: The first argument + b: The second one + """ + return a + b + cls.d + + @staticmethod + @tool + def sumfunc_static_method(a: int, b: int) -> int: + """Sum function as static method + + Args: + a: The first argument + b: The second one + """ + return a + b + 3 + + initialized_tool = ClassWithToolMethod(c=3) + for method_type in ("instance", "class", "static"): + cmethod = getattr(initialized_tool, f"sumfunc_{method_type}_method") + assert cmethod.output_type == "integer" + assert cmethod.description == f"Sum function as {method_type} method" + assert cmethod.name == f"sumfunc_{method_type}_method" + assert cmethod.inputs["a"]["description"] == "The first argument" + assert cmethod.inputs["b"]["description"] == "The second one" + correct_answer = initialized_tool.sumfunc_class_method(a=7, b=19) + assert correct_answer == 29 + assert cmethod.forward(a=7, b=19) == correct_answer + assert cmethod.forward(7, 19) == correct_answer + def test_tool_init_decorator_raises_issues(self): with pytest.raises(Exception) as e: @@ -163,6 +215,38 @@ def coolfunc(a: str, b: int) -> int: assert coolfunc.output_type == "number" assert "docstring has no description for the argument" in str(e) + def test_method_tool_init_decorator_raises_issues(self): + with pytest.raises(Exception) as e: + + class ToolMethod: + @tool + def cool_method(self, a: str, b: int): + """Cool method + + Args: + a: The first argument + b: The second one + """ + return a + b + + assert ToolMethod().cool_method.output_type == "number" + assert "Tool return type not found" in str(e) + + with pytest.raises(Exception) as e: + + class ToolMethod: + @tool + def cool_method(self, a: str, b: int): + """Cool function + + Args: + a: The first argument + """ + return a + b + + assert ToolMethod().cool_method.output_type == "number" + assert "docstring has no description for the argument" in str(e) + def test_saving_tool_raises_error_imports_outside_function(self): with pytest.raises(Exception) as e: import numpy as np @@ -331,6 +415,33 @@ def get_weather(location: str, celsius: Optional[bool] = False) -> str: assert get_weather.inputs["celsius"]["nullable"] assert "nullable" not in get_weather.inputs["location"] + def test_method_tool_from_decorator_optional_args(self): + class ToolMethod: + def __init__(self, desc: str): + self.desc = desc + + @tool + def get_weather(self, location: str, celsius: Optional[bool] = False) -> str: + """ + Get weather in the next days at given location but only if it's Paris. + + Args: + location: the location + celsius: the temperature type + """ + if location != "Paris": + raise Exception("no") + return f"The weather in {location} is {self.desc} with sunny skies and temperatures above 30°C" + + mytool = ToolMethod(desc="lovely") + assert "nullable" in mytool.get_weather.inputs["celsius"] + assert mytool.get_weather.inputs["celsius"]["nullable"] + assert "nullable" not in mytool.get_weather.inputs["location"] + assert ( + mytool.get_weather("Paris") + == "The weather in Paris is lovely with sunny skies and temperatures above 30°C" + ) + def test_tool_mismatching_nullable_args_raises_error(self): with pytest.raises(Exception) as e: @@ -407,6 +518,21 @@ def get_weather(location: str, celsius: bool = False) -> str: assert get_weather.inputs["celsius"]["nullable"] + def test_method_tool_default_parameters_is_nullable(self): + class ToolMethod: + @tool + def get_weather(self, location: str, celsius: bool = False) -> str: + """ + Get weather in the next days at given location. + + Args: + location: The location to get the weather for. + celsius: is the temperature given in celsius? + """ + return "The weather is UNGODLY with torrential rains and temperatures below -10°C" + + assert ToolMethod().get_weather.inputs["celsius"]["nullable"] + def test_tool_supports_any_none(self): @tool def get_weather(location: Any) -> None: @@ -423,6 +549,24 @@ def get_weather(location: Any) -> None: assert get_weather.inputs["location"]["type"] == "any" assert get_weather.output_type == "null" + def test_method_tool_errors_on_save(self): + class ToolMethod: + @tool + def get_weather(self, location: Any) -> None: + """ + Get weather in the next days at given location. + + Args: + location: The location to get the weather for. + """ + return + + mytool = ToolMethod() + with pytest.raises(Exception) as e: + with tempfile.TemporaryDirectory() as tmp_dir: + mytool.get_weather.save(tmp_dir) + assert "Cannot save objects created with" in str(e) + def test_tool_supports_array(self): @tool def get_weather(locations: List[str], months: Optional[Tuple[str, str]] = None) -> Dict[str, float]: @@ -438,6 +582,23 @@ def get_weather(locations: List[str], months: Optional[Tuple[str, str]] = None) assert get_weather.inputs["locations"]["type"] == "array" assert get_weather.inputs["months"]["type"] == "array" + def test_method_tool_supports_array(self): + class ToolMethod: + @tool + def get_weather(self, locations: List[str], months: Optional[Tuple[str, str]] = None) -> Dict[str, float]: + """ + Get weather in the next days at given locations. + + Args: + locations: The locations to get the weather for. + months: The months to get the weather for + """ + return + + mytool = ToolMethod() + assert mytool.get_weather.inputs["locations"]["type"] == "array" + assert mytool.get_weather.inputs["months"]["type"] == "array" + def test_saving_tool_produces_valid_pyhon_code_with_multiline_description(self): @tool def get_weather(location: Any) -> None: