Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method support for tool decorator #627

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
41 changes: 41 additions & 0 deletions docs/source/en/guided_tour.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,47 @@ Out[20]: 'ByteDance/AnimateDiff-Lightning'
> [!TIP]
> Read more on tools in the [dedicated tutorial](./tutorials/tools#what-is-a-tool-and-how-to-build-one).


### Tools from methods
Sometimes, you need to preserve state or perform an initialization step, such as handling tokens, setting up API clients, or sharing resources among tools. In these cases, you can encapsulate a tool within a class method.

You still use the @tool decorator, but you apply it directly to a method. Let's modify our previous example to handle authentication using a HuggingFace token:
```python
from smolagents import tool

class DownloadTools:
def __init__(self, hf_token: str):
self.hf_token = hf_token

@tool
def model_download_tool(self, task: str) -> str:
"""
Returns the most downloaded model of a given task on the Hugging Face Hub.
Uses the authentication token from initialization.

Args:
task: The task for which to get the most downloaded model.
"""
most_downloaded_model = next(iter(
list_models(
filter=task,
sort="downloads",
direction=-1,
token=self.hf_token # Pass the token from init
)
))
return most_downloaded_model.id

# Usage:
dl_tools = DownloadTools(hf_token="hf_xxxxxxxx")
agent = CodeAgent(tools=[dl_tools.model_download_tool], model=HfApiModel())
```

> [!TIP]
> You can create multiple tools in a single class by decorating more methods with @tool.
> You can also mix class-based and function-based tools in the same agent. Go with whichever approach suits your needs best!


## Multi-agents

Multi-agent systems have been introduced with Microsoft's framework [Autogen](https://huggingface.co/papers/2308.08155).
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/tutorials/tools.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ The custom tool subclasses [`Tool`] to inherit useful methods. The child class a

And that's all it needs to be used in an agent!

There's another way to build a tool. In the [guided_tour](../guided_tour), we implemented a tool using the `@tool` decorator. The [`tool`] decorator is the recommended way to define simple tools, but sometimes you need more than this: using several methods in a class for more clarity, or using additional class attributes.
There's another way to build a tool. In the [guided_tour](../guided_tour), we implemented a tool using the `@tool` decorator on a function The [`tool`] decorator is the recommended way to define simple tools, but sometimes you need more than this: using several methods in a class for more clarity, or using additional class attributes.

In this case, you can build your tool by subclassing [`Tool`] as described above.

Expand Down
17 changes: 12 additions & 5 deletions src/smolagents/_function_type_hints_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ class DocstringParsingException(Exception):
"""Exception raised for errors in parsing docstrings to generate JSON schemas"""


def get_json_schema(func: Callable) -> Dict:
class ToolCreationException(Exception):
"""Exception raised for errors in creating a tool from a function or method"""


def get_json_schema(func: Callable, skip_self_cls: bool = False) -> Dict:
"""
This function generates a JSON schema for a given function, based on its docstring and type hints. This is
mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
Expand All @@ -96,6 +100,7 @@ def get_json_schema(func: Callable) -> Dict:

Args:
func: The function to generate a JSON schema for.
skip_self_cls: Whether to skip the `self` and `cls` arguments in the schema. This is useful for bound methods in classes.

Returns:
A dictionary containing the JSON schema for the function.
Expand Down Expand Up @@ -197,8 +202,7 @@ def get_json_schema(func: Callable) -> Dict:
)
doc = doc.strip()
main_doc, param_descriptions, return_doc = _parse_google_format_docstring(doc)

json_schema = _convert_type_hints_to_json_schema(func)
json_schema = _convert_type_hints_to_json_schema(func, skip_self_cls=skip_self_cls)
if (return_dict := json_schema["properties"].pop("return", None)) is not None:
if return_doc is not None: # We allow a missing return docstring since most templates ignore it
return_dict["description"] = return_doc
Expand Down Expand Up @@ -273,16 +277,19 @@ def _parse_google_format_docstring(
return description, args_dict, returns


def _convert_type_hints_to_json_schema(func: Callable, error_on_missing_type_hints: bool = True) -> Dict:
def _convert_type_hints_to_json_schema(
func: Callable, error_on_missing_type_hints: bool = True, skip_self_cls: bool = False
) -> Dict:
type_hints = get_type_hints(func)
signature = inspect.signature(func)

properties = {}
for param_name, param_type in type_hints.items():
properties[param_name] = _parse_type_hint(param_type)

required = []
for param_name, param in signature.parameters.items():
if skip_self_cls and param_name in ["self", "cls"]:
continue
if param.annotation == inspect.Parameter.empty and error_on_missing_type_hints:
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
if param_name not in properties:
Expand Down
55 changes: 52 additions & 3 deletions src/smolagents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from huggingface_hub.utils import is_torch_available

from ._function_type_hints_utils import (
ToolCreationException,
TypeHintParsingException,
_convert_type_hints_to_json_schema,
get_imports,
Expand Down Expand Up @@ -250,9 +251,11 @@ def replacement(match):
"SpaceToolWrapper",
"LangChainToolWrapper",
"GradioToolWrapper",
"MethodToolWrapper",
]:
raise ValueError(
"Cannot save objects created with from_space, from_langchain or from_gradio, as this would create errors."
"Cannot save objects created with from_space, from_langchain, from_gradio, "
"or class methods decorated with @tool as this would create errors."
)

validate_tool_attributes(self.__class__)
Expand Down Expand Up @@ -838,9 +841,28 @@ def tool(tool_function: Callable) -> Tool:
Converts a function into an instance of a Tool subclass.

Args:
tool_function: Your function. Should have type hints for each input and a type hint for the output.
tool_function: Your function or class method. Should have type hints for each input and a type hint for the output.
Should also have a docstring description including an 'Args:' part where each argument is described.
"""
original_signature = inspect.signature(tool_function)
params = list(original_signature.parameters.values())

is_method = ( # inspect.ismethod won't work here because the decorator is applied before the method is bound
params
and params[0].name in ("self", "cls")
and hasattr(tool_function, "__qualname__")
and "." in tool_function.__qualname__
)

if is_method:
return _tool_from_method(tool_function, original_signature)
elif inspect.isfunction(tool_function):
return _tool_from_function(tool_function, original_signature)
else:
raise ToolCreationException("Tool decorator can only be used on functions or methods.")


def _tool_from_function(tool_function: Callable, original_signature: inspect.Signature) -> Tool:
tool_json_schema = get_json_schema(tool_function)["function"]
if "return" not in tool_json_schema:
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
Expand Down Expand Up @@ -868,7 +890,6 @@ def __init__(
output_type=tool_json_schema["return"]["type"],
function=tool_function,
)
original_signature = inspect.signature(tool_function)
new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)] + list(
original_signature.parameters.values()
)
Expand All @@ -877,6 +898,34 @@ def __init__(
return simple_tool


def _tool_from_method(tool_function: Callable, original_signature: inspect.Signature) -> Tool:
tool_json_schema = get_json_schema(tool_function, skip_self_cls=True)["function"]
if "return" not in tool_json_schema:
raise TypeHintParsingException("Tool return type not found: make sure your method has a return type hint!")

class MethodToolWrapper(Tool):
def __init__(self):
self.name = tool_json_schema["name"]
self.description = tool_json_schema["description"]
self.inputs = tool_json_schema["parameters"]["properties"]
self.output_type = tool_json_schema["return"]["type"]
self.is_initialized = True
# Temp forward function; will be replaced with a bound instance in __get__. (this is for sig validation)
self.forward = lambda: 0
new_parameters = [v for v in original_signature.parameters.values() if v.name not in ["self", "cls"]]
modified_sig = original_signature.replace(parameters=new_parameters)
self.forward.__signature__ = modified_sig

# When the decorator is first used, it's on an unbound class method.
# So we have to use a descriptor here to bind the instance to the forward method when accessed.
def __get__(self, instance, owner):
func = getattr(tool_function, "__func__", tool_function)
self.forward = lambda *args, **kwargs: func(instance, *args, **kwargs)
return self

return MethodToolWrapper()


class PipelineTool(Tool):
"""
A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
Expand Down
2 changes: 1 addition & 1 deletion tests/test_all_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def setup_class(cls):

load_dotenv()

cls.md_files = list(cls.docs_dir.rglob("*.md"))
cls.md_files = list(cls.docs_dir.rglob("*.mdx"))
if not cls.md_files:
raise ValueError(f"No markdown files found in {cls.docs_dir}")

Expand Down
Loading