From daf654991fb78e00d2d95909ad1c8f09349c691d Mon Sep 17 00:00:00 2001 From: Aymeric Date: Wed, 11 Sep 2024 15:16:00 +0200 Subject: [PATCH 1/6] Decorator for tool building --- docs/source/en/agents.md | 74 ++++++++++--------------------- docs/source/en/agents_advanced.md | 62 +++++++++++++++++++++++++- src/transformers/agents/search.py | 58 +++++++++++------------- src/transformers/agents/tools.py | 31 ++++++++++++- tests/agents/test_tools_common.py | 17 +++++++ 5 files changed, 158 insertions(+), 84 deletions(-) diff --git a/docs/source/en/agents.md b/docs/source/en/agents.md index b100e39f1c95..a92e50530991 100644 --- a/docs/source/en/agents.md +++ b/docs/source/en/agents.md @@ -325,62 +325,35 @@ model = next(iter(list_models(filter=task, sort="downloads", direction=-1))) print(model.id) ``` -This code can be converted into a class that inherits from the [`Tool`] superclass. +This code can quickly be converted into a tool, just by wrapping it in a function and adding the `tool` decorator: -The custom tool needs: -- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name is `model_download_counter`. -- An attribute `description` is used to populate the agent's system prompt. -- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input. -- An `output_type` attribute, which specifies the output type. -- A `forward` method which contains the inference code to be executed. - - -```python -from transformers import Tool -from huggingface_hub import list_models - -class HFModelDownloadsTool(Tool): - name = "model_download_counter" - description = ( - "This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. " - "It returns the name of the checkpoint." 
- ) - - inputs = { - "task": { - "type": "text", - "description": "the task category (such as text-classification, depth-estimation, etc)", - } - } - output_type = "text" - - def forward(self, task: str): - model = next(iter(list_models(filter=task, sort="downloads", direction=-1))) - return model.id -``` - -Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use. - - -```python -from model_downloads import HFModelDownloadsTool - -tool = HFModelDownloadsTool() -``` - -You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access. - -```python -tool.push_to_hub("{your_username}/hf-model-downloads") +```py +@tool +def model_download_counter(task: str) -> str: + """ + This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. + It returns the name of the checkpoint. + + Args: + task: The task for which to find the most downloaded model. + """ + model = next(iter(list_models(filter=task, sort="downloads", direction=-1))) + return model.id ``` -Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent. +The function needs: +- A clear name. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's call it `model_download_counter`. +- Type hints on both inputs and output. +- A description that includes an 'Args:' part where each argument is described (without a type indication this time; it will be pulled from the type hint). +All of these will be automatically baked into the agent's system prompt upon initialization, so strive to make them as clear as possible! 
-```python -from transformers import load_tool, CodeAgent +> [!TIP] +> This definition format is the same as tool schemas used in `apply_chat_template`, the only difference is the added `tool` decorator: read more on our tool use API [here](https://huggingface.co/blog/unified-tool-use#passing-tools-to-a-chat-template). -model_download_tool = load_tool("m-ric/hf-model-downloads") +Then you can directly initialize your agent: +```py +from transformers import CodeAgent agent = CodeAgent(tools=[model_download_tool], llm_engine=llm_engine) agent.run( "Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?" @@ -400,7 +373,6 @@ print(f"The most downloaded model for the 'text-to-video' task is {most_download And the output: `"The most downloaded model for the 'text-to-video' task is ByteDance/AnimateDiff-Lightning."` - ### Manage your agent's toolbox If you have already initialized an agent, it is inconvenient to reinitialize it from scratch with a tool you want to use. With Transformers, you can manage an agent's toolbox by adding or replacing a tool. diff --git a/docs/source/en/agents_advanced.md b/docs/source/en/agents_advanced.md index e7469a310c41..e68afb4109c7 100644 --- a/docs/source/en/agents_advanced.md +++ b/docs/source/en/agents_advanced.md @@ -60,7 +60,67 @@ manager_agent.run("Who is the CEO of Hugging Face?") > For an in-depth example of an efficient multi-agent implementation, see [how we pushed our multi-agent system to the top of the GAIA leaderboard](https://huggingface.co/blog/beating-gaia). -## Use tools from gradio or LangChain +## Advanced tool usage + +### Directly define a tool by subclassing Tool, and share it to the Hub + +Let's take again the tool example from main documentation, for which we had implemented a `tool` decorator. 
+ +If you need to add variation, like custom attributes for your tool, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass. + +The custom tool needs: +- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name it `model_download_counter`. +- An attribute `description`, which is used to populate the agent's system prompt. +- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input. +- An `output_type` attribute, which specifies the output type. +- A `forward` method which contains the inference code to be executed. + +```python +from transformers import Tool +from huggingface_hub import list_models + +class HFModelDownloadsTool(Tool): + name = "model_download_counter" + description = ( + "This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. " + "It returns the name of the checkpoint." + ) + + inputs = { + "task": { + "type": "text", + "description": "the task category (such as text-classification, depth-estimation, etc)", + } + } + output_type = "text" + + def forward(self, task: str): + model = next(iter(list_models(filter=task, sort="downloads", direction=-1))) + return model.id +``` + +Now that the custom `HFModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use. + + +```python +from model_downloads import HFModelDownloadsTool + +tool = HFModelDownloadsTool() +``` + +You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access. 
+ +```python +tool.push_to_hub("{your_username}/hf-model-downloads") +``` + +Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent. + +```python +from transformers import load_tool, CodeAgent + +model_download_tool = load_tool("m-ric/hf-model-downloads") +``` ### Use gradio-tools diff --git a/src/transformers/agents/search.py b/src/transformers/agents/search.py index 5ce36bf99b56..3dab27fe5308 100644 --- a/src/transformers/agents/search.py +++ b/src/transformers/agents/search.py @@ -19,7 +19,7 @@ import requests from requests.exceptions import RequestException -from .tools import Tool +from .tools import Tool, tool class DuckDuckGoSearchTool(Tool): @@ -40,38 +40,34 @@ def forward(self, query: str) -> str: return results -class VisitWebpageTool(Tool): - name = "visit_webpage" - description = "Visits a wbepage at the given url and returns its content as a markdown string." - inputs = { - "url": { - "type": "text", - "description": "The url of the webpage to visit.", - } - } - output_type = "text" +@tool +def visit_webpage(url: str) -> str: + """ + Visits a webpage at the given url and returns its content as a markdown string. - def forward(self, url: str) -> str: - try: - from markdownify import markdownify - except ImportError: - raise ImportError( - "You must install package `markdownify` to run this tool: for instance run `pip install markdownify`." - ) - try: - # Send a GET request to the URL - response = requests.get(url) - response.raise_for_status() # Raise an exception for bad status codes + Args: + url: The url of the webpage to visit. + """ + try: + from markdownify import markdownify + except ImportError: + raise ImportError( + "You must install package `markdownify` to run this tool: for instance run `pip install markdownify`." 
+ ) + try: + # Send a GET request to the URL + response = requests.get(url) + response.raise_for_status() # Raise an exception for bad status codes - # Convert the HTML content to Markdown - markdown_content = markdownify(response.text).strip() + # Convert the HTML content to Markdown + markdown_content = markdownify(response.text).strip() - # Remove multiple line breaks - markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content) + # Remove multiple line breaks + markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content) - return markdown_content + return markdown_content - except RequestException as e: - return f"Error fetching the webpage: {str(e)}" - except Exception as e: - return f"An unexpected error occurred: {str(e)}" + except RequestException as e: + return f"Error fetching the webpage: {str(e)}" + except Exception as e: + return f"An unexpected error occurred: {str(e)}" diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index f97ccc2e10ce..b6d81bb405a1 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -21,7 +21,7 @@ import os import tempfile from functools import lru_cache -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session @@ -36,6 +36,7 @@ from ..utils import ( CONFIG_NAME, cached_file, + get_json_schema, is_accelerate_available, is_torch_available, is_vision_available, @@ -808,3 +809,31 @@ def __init__(self, collection_slug: str, token: Optional[str] = None): self._collection = get_collection(collection_slug, token=token) self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"} self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids} + + +def tool(tool_function: 
Callable) -> Tool: + """ + Decorator that turns a function into an instance of a specific Tool subclass + + Args: + tool_function: Your function. Should have type hints for each input and for the output, and a description with 'Args:' part where each argument is described. + """ + parameters = get_json_schema(tool_function)["function"] + class_name = f"{parameters['name'].capitalize()}Tool" + specific_tool_class = type( + class_name, + (Tool,), + { + "name": parameters["name"], + "description": parameters["description"], + "inputs": parameters["parameters"], + "output_type": parameters["return"]["type"], + }, + ) + + def forward(self, *args, **kwargs): + return tool_function(*args, **kwargs) + + setattr(specific_tool_class, "forward", forward) + + return specific_tool_class() diff --git a/tests/agents/test_tools_common.py b/tests/agents/test_tools_common.py index 679473d0f24b..c8ca298e4740 100644 --- a/tests/agents/test_tools_common.py +++ b/tests/agents/test_tools_common.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import unittest from pathlib import Path from typing import Dict, Union @@ -19,6 +20,7 @@ from transformers import is_torch_available, is_vision_available from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText +from transformers.agents.tools import tool from transformers.testing_utils import get_tests_dir, is_agent_test @@ -100,3 +102,18 @@ def test_agent_types_inputs(self): for _input, expected_input in zip(inputs, self.tool.inputs.values()): input_type = expected_input["type"] _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) + + +class ToolTests(unittest.TestCase): + def test_tool_init_with_decorator(self): + @tool + def coolfunc(a: str, b: int) -> tuple: + """Cool function + + Args: + a: The first argument + b: The third one + """ + return b + 2, a + + assert coolfunc.output_type == "tuple" From 2447e7e24827c7c58fb08e7fcce688ad76a2c75c Mon Sep 17 00:00:00 2001 From: Aymeric Date: Wed, 11 Sep 2024 19:24:53 +0200 Subject: [PATCH 2/6] Fix tests --- docs/source/en/agents_advanced.md | 13 ++-- src/transformers/agents/agent_types.py | 2 +- src/transformers/agents/agents.py | 16 ++--- src/transformers/agents/default_tools.py | 11 ++-- .../agents/document_question_answering.py | 6 +- .../agents/image_question_answering.py | 4 +- src/transformers/agents/prompts.py | 2 +- src/transformers/agents/search.py | 60 ++++++++++--------- src/transformers/agents/speech_to_text.py | 2 +- src/transformers/agents/text_to_speech.py | 2 +- src/transformers/agents/tools.py | 13 ++-- src/transformers/agents/translation.py | 8 +-- src/transformers/utils/chat_template_utils.py | 11 +++- tests/agents/test_agents.py | 1 - tests/agents/test_final_answer.py | 11 ++-- tests/agents/test_python_interpreter.py | 8 +-- tests/agents/test_tools_common.py | 10 ++-- 17 files changed, 92 insertions(+), 88 deletions(-) diff --git a/docs/source/en/agents_advanced.md b/docs/source/en/agents_advanced.md index e68afb4109c7..733b52f2157c 100644 --- 
a/docs/source/en/agents_advanced.md +++ b/docs/source/en/agents_advanced.md @@ -75,24 +75,25 @@ The custom tool needs: - An `output_type` attribute, which specifies the output type. - A `forward` method which contains the inference code to be executed. +The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema). + ```python from transformers import Tool from huggingface_hub import list_models class HFModelDownloadsTool(Tool): name = "model_download_counter" - description = ( - "This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. " - "It returns the name of the checkpoint." - ) + description = """ + This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. + It returns the name of the checkpoint.""" inputs = { "task": { - "type": "text", + "type": "string", "description": "the task category (such as text-classification, depth-estimation, etc)", } } - output_type = "text" + output_type = "string" def forward(self, task: str): model = next(iter(list_models(filter=task, sort="downloads", direction=-1))) diff --git a/src/transformers/agents/agent_types.py b/src/transformers/agents/agent_types.py index 4a36eaaee051..f5be7462657c 100644 --- a/src/transformers/agents/agent_types.py +++ b/src/transformers/agents/agent_types.py @@ -234,7 +234,7 @@ def to_string(self): return self._path -AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio} +AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio} INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage} if is_torch_available(): diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index 5a4aea28d970..446aa4f2508e 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -626,10 +626,9 @@ def run(self, task: str, 
return_generated_code: bool = False, **kwargs): Example: ```py - from transformers.agents import CodeAgent, PythonInterpreterTool + from transformers.agents import CodeAgent - python_interpreter = PythonInterpreterTool() - agent = CodeAgent(tools=[python_interpreter]) + agent = CodeAgent(tools=[]) agent.run("What is the result of 2 power 3.7384?") ``` """ @@ -1019,20 +1018,17 @@ def step(self): arguments = {} observation = self.execute_tool_call(tool_name, arguments) observation_type = type(observation) - if observation_type == AgentText: - updated_information = str(observation).strip() - else: - # TODO: observation naming could allow for different names of same type + if observation_type in [AgentImage, AgentAudio]: if observation_type == AgentImage: observation_name = "image.png" elif observation_type == AgentAudio: observation_name = "audio.mp3" - else: - observation_name = "object.object" + # TODO: observation naming could allow for different names of same type self.state[observation_name] = observation updated_information = f"Stored '{observation_name}' in memory." - + else: + updated_information = str(observation).strip() self.logger.info(updated_information) current_step_logs["observation"] = updated_information return current_step_logs diff --git a/src/transformers/agents/default_tools.py b/src/transformers/agents/default_tools.py index b02b12d5287c..10097fa86743 100644 --- a/src/transformers/agents/default_tools.py +++ b/src/transformers/agents/default_tools.py @@ -25,7 +25,7 @@ from ..utils import is_offline_mode from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code -from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool +from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool, tool def custom_print(*args): @@ -152,8 +152,7 @@ class PythonInterpreterTool(Tool): name = "python_interpreter" description = "This is a tool that evaluates python code. It can be used to perform calculations." 
- output_type = "text" - available_tools = BASE_PYTHON_TOOLS.copy() + output_type = "string" def __init__(self, *args, authorized_imports=None, **kwargs): if authorized_imports is None: @@ -162,7 +161,7 @@ def __init__(self, *args, authorized_imports=None, **kwargs): self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports)) self.inputs = { "code": { - "type": "text", + "type": "string", "description": ( "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, " f"else you will get an error. This code can only import the following python libraries: {authorized_imports}." @@ -173,7 +172,7 @@ def __init__(self, *args, authorized_imports=None, **kwargs): def forward(self, code): output = str( - evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports) + evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports) ) return output @@ -181,7 +180,7 @@ def forward(self, code): class FinalAnswerTool(Tool): name = "final_answer" description = "Provides a final answer to the given problem." - inputs = {"answer": {"type": "text", "description": "The final answer to the problem"}} + inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}} output_type = "any" def forward(self, answer): diff --git a/src/transformers/agents/document_question_answering.py b/src/transformers/agents/document_question_answering.py index 030120ac6c7f..23ae5b042912 100644 --- a/src/transformers/agents/document_question_answering.py +++ b/src/transformers/agents/document_question_answering.py @@ -31,7 +31,7 @@ class DocumentQuestionAnsweringTool(PipelineTool): default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa" - description = "This is a tool that answers a question about an document (pdf). It returns a text that contains the answer to the question." 
+ description = "This is a tool that answers a question about an document (pdf). It returns a string that contains the answer to the question." name = "document_qa" pre_processor_class = AutoProcessor model_class = VisionEncoderDecoderModel @@ -41,9 +41,9 @@ class DocumentQuestionAnsweringTool(PipelineTool): "type": "image", "description": "The image containing the information. Can be a PIL Image or a string path to the image.", }, - "question": {"type": "text", "description": "The question in English"}, + "question": {"type": "string", "description": "The question in English"}, } - output_type = "text" + output_type = "string" def __init__(self, *args, **kwargs): if not is_vision_available(): diff --git a/src/transformers/agents/image_question_answering.py b/src/transformers/agents/image_question_answering.py index 020d22c47f91..de0efb7b6f38 100644 --- a/src/transformers/agents/image_question_answering.py +++ b/src/transformers/agents/image_question_answering.py @@ -38,9 +38,9 @@ class ImageQuestionAnsweringTool(PipelineTool): "type": "image", "description": "The image containing the information. 
Can be a PIL Image or a string path to the image.", }, - "question": {"type": "text", "description": "The question in English"}, + "question": {"type": "string", "description": "The question in English"}, } - output_type = "text" + output_type = "string" def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) diff --git a/src/transformers/agents/prompts.py b/src/transformers/agents/prompts.py index de8ad1d28490..266d0d78796f 100644 --- a/src/transformers/agents/prompts.py +++ b/src/transformers/agents/prompts.py @@ -199,7 +199,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): Action: { "action": "image_generator", - "action_input": {"text": ""A portrait of John Doe, a 55-year-old man living in Canada.""} + "action_input": {"prompt": ""A portrait of John Doe, a 55-year-old man living in Canada.""} } Observation: "image.png" diff --git a/src/transformers/agents/search.py b/src/transformers/agents/search.py index 3dab27fe5308..4889de084711 100644 --- a/src/transformers/agents/search.py +++ b/src/transformers/agents/search.py @@ -26,7 +26,7 @@ class DuckDuckGoSearchTool(Tool): name = "web_search" description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements. Each result has keys 'title', 'href' and 'body'.""" - inputs = {"query": {"type": "text", "description": "The search query to perform."}} + inputs = {"query": {"type": "string", "description": "The search query to perform."}} output_type = "any" def forward(self, query: str) -> str: @@ -40,34 +40,38 @@ def forward(self, query: str) -> str: return results -@tool -def visit_webpage(url: str) -> str: - """ - Visits a webpage at the given url and returns its content as a markdown string. - - Args: - url: The url of the webpage to visit. 
- """ - try: - from markdownify import markdownify - except ImportError: - raise ImportError( - "You must install package `markdownify` to run this tool: for instance run `pip install markdownify`." - ) - try: - # Send a GET request to the URL - response = requests.get(url) - response.raise_for_status() # Raise an exception for bad status codes +class VisitWebpageTool(Tool): + name = "visit_webpage" + description = "Visits a wbepage at the given url and returns its content as a markdown string." + inputs = { + "url": { + "type": "string", + "description": "The url of the webpage to visit.", + } + } + output_type = "string" + + def forward(self, url: str) -> str: + try: + from markdownify import markdownify + except ImportError: + raise ImportError( + "You must install package `markdownify` to run this tool: for instance run `pip install markdownify`." + ) + try: + # Send a GET request to the URL + response = requests.get(url) + response.raise_for_status() # Raise an exception for bad status codes - # Convert the HTML content to Markdown - markdown_content = markdownify(response.text).strip() + # Convert the HTML content to Markdown + markdown_content = markdownify(response.text).strip() - # Remove multiple line breaks - markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content) + # Remove multiple line breaks + markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content) - return markdown_content + return markdown_content - except RequestException as e: - return f"Error fetching the webpage: {str(e)}" - except Exception as e: - return f"An unexpected error occurred: {str(e)}" + except RequestException as e: + return f"Error fetching the webpage: {str(e)}" + except Exception as e: + return f"An unexpected error occurred: {str(e)}" diff --git a/src/transformers/agents/speech_to_text.py b/src/transformers/agents/speech_to_text.py index 817b6319e6b8..8061651a0864 100644 --- a/src/transformers/agents/speech_to_text.py +++ 
b/src/transformers/agents/speech_to_text.py @@ -27,7 +27,7 @@ class SpeechToTextTool(PipelineTool): model_class = WhisperForConditionalGeneration inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}} - output_type = "text" + output_type = "string" def encode(self, audio): return self.pre_processor(audio, return_tensors="pt") diff --git a/src/transformers/agents/text_to_speech.py b/src/transformers/agents/text_to_speech.py index 3166fab8023c..ed41ef6017ae 100644 --- a/src/transformers/agents/text_to_speech.py +++ b/src/transformers/agents/text_to_speech.py @@ -36,7 +36,7 @@ class TextToSpeechTool(PipelineTool): model_class = SpeechT5ForTextToSpeech post_processor_class = SpeechT5HifiGan - inputs = {"text": {"type": "text", "description": "The text to read out loud (in English)"}} + inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}} output_type = "audio" def setup(self): diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index b6d81bb405a1..b907c890af17 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -383,7 +383,7 @@ def __init__(self, _gradio_tool): super().__init__() self.name = _gradio_tool.name self.description = _gradio_tool.description - self.output_type = "text" + self.output_type = "string" self._gradio_tool = _gradio_tool func_args = list(inspect.signature(_gradio_tool.run).parameters.keys()) self.inputs = {key: "" for key in func_args} @@ -405,7 +405,7 @@ def __init__(self, _langchain_tool): self.name = _langchain_tool.name.lower() self.description = _langchain_tool.description self.inputs = parse_langchain_args(_langchain_tool.args) - self.output_type = "text" + self.output_type = "string" self.langchain_tool = _langchain_tool def forward(self, *args, **kwargs): @@ -422,6 +422,7 @@ def forward(self, *args, **kwargs): DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """ - {{ tool.name }}: {{ tool.description }} Takes inputs: 
{{tool.inputs}} + Returns an output of type: {{tool.output_type}} """ @@ -622,18 +623,18 @@ def fn(*args, **kwargs): gradio_inputs = [] for input_name, input_details in tool_class.inputs.items(): input_type = input_details["type"] - if input_type == "text": - gradio_inputs.append(gr.Textbox(label=input_name)) - elif input_type == "image": + if input_type == "image": gradio_inputs.append(gr.Image(label=input_name)) elif input_type == "audio": gradio_inputs.append(gr.Audio(label=input_name)) + elif input_type in ["string", "integer", "number"]: + gradio_inputs.append(gr.Textbox(label=input_name)) else: error_message = f"Input type '{input_type}' not supported." raise ValueError(error_message) gradio_output = tool_class.output_type - assert gradio_output in ["text", "image", "audio"], f"Output type '{gradio_output}' not supported." + assert gradio_output in ["string", "image", "audio"], f"Output type '{gradio_output}' not supported." gr.Interface( fn=fn, diff --git a/src/transformers/agents/translation.py b/src/transformers/agents/translation.py index efc97c6e0b20..7ae61f9679b8 100644 --- a/src/transformers/agents/translation.py +++ b/src/transformers/agents/translation.py @@ -249,17 +249,17 @@ class TranslationTool(PipelineTool): model_class = AutoModelForSeq2SeqLM inputs = { - "text": {"type": "text", "description": "The text to translate"}, + "text": {"type": "string", "description": "The text to translate"}, "src_lang": { - "type": "text", + "type": "string", "description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'", }, "tgt_lang": { - "type": "text", + "type": "string", "description": "The language for the desired ouput language. 
Written in plain English, such as 'Romanian', or 'Albanian'", }, } - output_type = "text" + output_type = "string" def encode(self, text, src_lang, tgt_lang): if src_lang not in self.lang_to_code: diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index aabaf4a36665..1ca8e59d926d 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -22,7 +22,7 @@ from packaging import version -from .import_utils import is_jinja_available +from .import_utils import is_jinja_available, is_vision_available, is_torch_available if is_jinja_available(): @@ -32,6 +32,12 @@ else: jinja2 = None +if is_vision_available(): + from PIL.Image import Image + +if is_torch_available(): + from torch import Tensor + BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) # Extracts the initial segment of the docstring, containing the function description @@ -63,13 +69,14 @@ class DocstringParsingException(Exception): pass - def _get_json_schema_type(param_type: str) -> Dict[str, str]: type_mapping = { int: {"type": "integer"}, float: {"type": "number"}, str: {"type": "string"}, bool: {"type": "boolean"}, + Image: {"type": "image"}, + Tensor: {"type": "audio"}, Any: {}, } return type_mapping.get(param_type, {"type": "object"}) diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 67cb31b7dac3..d0fc6879b6b1 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -68,7 +68,6 @@ def fake_react_code_llm(messages, stop_sequences=None, grammar=None) -> str: Code: ```py result = 2**3.6452 -print(result) ``` """ else: # We're at step 2 diff --git a/tests/agents/test_final_answer.py b/tests/agents/test_final_answer.py index 59d5dec84b57..3bd7c0f124fc 100644 --- a/tests/agents/test_final_answer.py +++ b/tests/agents/test_final_answer.py @@ -21,20 +21,17 @@ from transformers import is_torch_available, load_tool from transformers.agents.agent_types 
import AGENT_TYPE_MAPPING +from transformers.agents.default_tools import FinalAnswerTool from transformers.testing_utils import get_tests_dir, require_torch -from .test_tools_common import ToolTesterMixin - - if is_torch_available(): import torch -class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin): +class FinalAnswerToolTester(unittest.TestCase): def setUp(self): self.inputs = {"answer": "Final answer"} - self.tool = load_tool("final_answer") - self.tool.setup() + self.tool = FinalAnswerTool() def test_exact_match_arg(self): result = self.tool("Final answer") @@ -52,7 +49,7 @@ def create_inputs(self): ) } inputs_audio = {"answer": torch.Tensor(np.ones(3000))} - return {"text": inputs_text, "image": inputs_image, "audio": inputs_audio} + return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio} @require_torch def test_agent_type_output(self): diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index 84710cfec685..71dcbe3b641c 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -391,8 +391,9 @@ def test_if_conditions(self): code = """char='a' if char.isalpha(): print('2')""" - result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) - assert result == "2" + state = {} + result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) + assert state["print_outputs"] == "2\n" def test_imports(self): code = "import math\nmath.sqrt(4)" @@ -469,7 +470,7 @@ def test_print_output(self): code = "print('Hello world!')\nprint('Ok no one cares')" state = {} result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) - assert result == "Ok no one cares" + assert result == None assert state["print_outputs"] == "Hello world!\nOk no one cares\n" # test print in function @@ -594,7 +595,6 @@ def test_print(self): code = "print(min([1, 2, 3]))" state = {} result = evaluate_python_code(code, {"min": min, "print": print}, state=state) - assert 
result == "1" assert state["print_outputs"] == "1\n" def test_types_as_objects(self): diff --git a/tests/agents/test_tools_common.py b/tests/agents/test_tools_common.py index c8ca298e4740..521783a7b6c9 100644 --- a/tests/agents/test_tools_common.py +++ b/tests/agents/test_tools_common.py @@ -31,7 +31,7 @@ from PIL import Image -AUTHORIZED_TYPES = ["text", "audio", "image", "any"] +AUTHORIZED_TYPES = ["string", "number", "audio", "image", "any"] def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]): @@ -40,7 +40,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]): for input_name, input_desc in tool_inputs.items(): input_type = input_desc["type"] - if input_type == "text": + if input_type == "string": inputs[input_name] = "Text input" elif input_type == "image": inputs[input_name] = Image.open( @@ -56,7 +56,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]): def output_type(output): if isinstance(output, (str, AgentText)): - return "text" + return "string" elif isinstance(output, (Image.Image, AgentImage)): return "image" elif isinstance(output, (torch.Tensor, AgentAudio)): @@ -107,7 +107,7 @@ def test_agent_types_inputs(self): class ToolTests(unittest.TestCase): def test_tool_init_with_decorator(self): @tool - def coolfunc(a: str, b: int) -> tuple: + def coolfunc(a: str, b: int) -> float: """Cool function Args: @@ -116,4 +116,4 @@ def coolfunc(a: str, b: int) -> tuple: """ return b + 2, a - assert coolfunc.output_type == "tuple" + assert coolfunc.output_type == "number" From 98f5463fee0009538eb1ca9f7e466c6145e3a731 Mon Sep 17 00:00:00 2001 From: Aymeric Date: Wed, 11 Sep 2024 19:27:42 +0200 Subject: [PATCH 3/6] Formatting --- src/transformers/agents/agents.py | 2 +- src/transformers/agents/default_tools.py | 2 +- src/transformers/agents/prompts.py | 2 +- src/transformers/agents/search.py | 4 ++-- src/transformers/utils/chat_template_utils.py | 3 ++- tests/agents/test_final_answer.py | 3 ++- 
tests/agents/test_python_interpreter.py | 6 +++--- tests/agents/test_tools_common.py | 2 +- 8 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index 446aa4f2508e..73b7186d25a3 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -22,7 +22,7 @@ from .. import is_torch_available from ..utils import logging as transformers_logging from ..utils.import_utils import is_pygments_available -from .agent_types import AgentAudio, AgentImage, AgentText +from .agent_types import AgentAudio, AgentImage from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools from .llm_engine import HfApiEngine, MessageRole from .prompts import ( diff --git a/src/transformers/agents/default_tools.py b/src/transformers/agents/default_tools.py index 10097fa86743..3946aa9f8735 100644 --- a/src/transformers/agents/default_tools.py +++ b/src/transformers/agents/default_tools.py @@ -25,7 +25,7 @@ from ..utils import is_offline_mode from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code -from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool, tool +from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool def custom_print(*args): diff --git a/src/transformers/agents/prompts.py b/src/transformers/agents/prompts.py index 266d0d78796f..7a84b1db44fa 100644 --- a/src/transformers/agents/prompts.py +++ b/src/transformers/agents/prompts.py @@ -199,7 +199,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): Action: { "action": "image_generator", - "action_input": {"prompt": ""A portrait of John Doe, a 55-year-old man living in Canada.""} + "action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."} } Observation: "image.png" diff --git a/src/transformers/agents/search.py b/src/transformers/agents/search.py index 4889de084711..f50a7c6ab8f9 100644 --- a/src/transformers/agents/search.py +++ 
b/src/transformers/agents/search.py @@ -19,7 +19,7 @@ import requests from requests.exceptions import RequestException -from .tools import Tool, tool +from .tools import Tool class DuckDuckGoSearchTool(Tool): @@ -50,7 +50,7 @@ class VisitWebpageTool(Tool): } } output_type = "string" - + def forward(self, url: str) -> str: try: from markdownify import markdownify diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 1ca8e59d926d..74912ce30146 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -22,7 +22,7 @@ from packaging import version -from .import_utils import is_jinja_available, is_vision_available, is_torch_available +from .import_utils import is_jinja_available, is_torch_available, is_vision_available if is_jinja_available(): @@ -69,6 +69,7 @@ class DocstringParsingException(Exception): pass + def _get_json_schema_type(param_type: str) -> Dict[str, str]: type_mapping = { int: {"type": "integer"}, diff --git a/tests/agents/test_final_answer.py b/tests/agents/test_final_answer.py index 3bd7c0f124fc..ee744bbd7bd3 100644 --- a/tests/agents/test_final_answer.py +++ b/tests/agents/test_final_answer.py @@ -19,11 +19,12 @@ import numpy as np from PIL import Image -from transformers import is_torch_available, load_tool +from transformers import is_torch_available from transformers.agents.agent_types import AGENT_TYPE_MAPPING from transformers.agents.default_tools import FinalAnswerTool from transformers.testing_utils import get_tests_dir, require_torch + if is_torch_available(): import torch diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index 71dcbe3b641c..15e5ad7bb3a3 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -392,7 +392,7 @@ def test_if_conditions(self): if char.isalpha(): print('2')""" state = {} - result = evaluate_python_code(code, 
BASE_PYTHON_TOOLS, state=state) + evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) assert state["print_outputs"] == "2\n" def test_imports(self): @@ -470,7 +470,7 @@ def test_print_output(self): code = "print('Hello world!')\nprint('Ok no one cares')" state = {} result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) - assert result == None + assert result is None assert state["print_outputs"] == "Hello world!\nOk no one cares\n" # test print in function @@ -594,7 +594,7 @@ def method_that_raises(self): def test_print(self): code = "print(min([1, 2, 3]))" state = {} - result = evaluate_python_code(code, {"min": min, "print": print}, state=state) + evaluate_python_code(code, {"min": min, "print": print}, state=state) assert state["print_outputs"] == "1\n" def test_types_as_objects(self): diff --git a/tests/agents/test_tools_common.py b/tests/agents/test_tools_common.py index 521783a7b6c9..47770669c3eb 100644 --- a/tests/agents/test_tools_common.py +++ b/tests/agents/test_tools_common.py @@ -31,7 +31,7 @@ from PIL import Image -AUTHORIZED_TYPES = ["string", "number", "audio", "image", "any"] +AUTHORIZED_TYPES = ["string", "boolean", "integer", "number", "audio", "image", "any"] def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]): From d6f703f62021c3cbb732ba1c15b8afbda62bfb1f Mon Sep 17 00:00:00 2001 From: Aymeric Date: Fri, 13 Sep 2024 18:51:40 +0200 Subject: [PATCH 4/6] Re-add ToolTesterMixin --- tests/agents/test_final_answer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/agents/test_final_answer.py b/tests/agents/test_final_answer.py index ee744bbd7bd3..91bdd65e89a8 100644 --- a/tests/agents/test_final_answer.py +++ b/tests/agents/test_final_answer.py @@ -24,12 +24,14 @@ from transformers.agents.default_tools import FinalAnswerTool from transformers.testing_utils import get_tests_dir, require_torch +from .test_tools_common import ToolTesterMixin + if is_torch_available(): import torch -class 
FinalAnswerToolTester(unittest.TestCase): +class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin): def setUp(self): self.inputs = {"answer": "Final answer"} self.tool = FinalAnswerTool() From 66d5283d254b901a83aafec1f2253ca821081630 Mon Sep 17 00:00:00 2001 From: Aymeric Date: Mon, 16 Sep 2024 12:22:15 +0200 Subject: [PATCH 5/6] Finalize tool checks --- src/transformers/agents/tools.py | 78 +++++++++++++++++++++++-------- tests/agents/test_agents.py | 3 +- tests/agents/test_tools_common.py | 56 +++++++++++++++++++++- 3 files changed, 114 insertions(+), 23 deletions(-) diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index b907c890af17..46a9a2fa1050 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -16,11 +16,12 @@ # limitations under the License. import base64 import importlib +import inspect import io import json import os import tempfile -from functools import lru_cache +from functools import lru_cache, wraps from typing import Any, Callable, Dict, List, Optional, Union from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder @@ -35,6 +36,7 @@ from ..models.auto import AutoProcessor from ..utils import ( CONFIG_NAME, + TypeHintParsingException, cached_file, get_json_schema, is_accelerate_available, @@ -85,6 +87,20 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs): """ +def validate_after_init(cls): + original_init = cls.__init__ + + @wraps(original_init) + def new_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + if not isinstance(self, PipelineTool): + self.validate_arguments() + + cls.__init__ = new_init + return cls + + +@validate_after_init class Tool: """ A base class for the functions used by the agent. 
Subclass this and implement the `__call__` method as well as the @@ -115,17 +131,35 @@ class Tool: def __init__(self, *args, **kwargs): self.is_initialized = False - def validate_attributes(self): + def validate_arguments(self): required_attributes = { "description": str, "name": str, "inputs": Dict, - "output_type": type, + "output_type": str, } + authorized_types = ["string", "integer", "number", "image", "audio", "any"] + for attr, expected_type in required_attributes.items(): attr_value = getattr(self, attr, None) if not isinstance(attr_value, expected_type): - raise TypeError(f"Instance attribute {attr} must exist and be of type {expected_type.__name__}") + raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.") + for input_name, input_content in self.inputs.items(): + assert "type" in input_content, f"Input '{input_name}' should specify a type." + if input_content["type"] not in authorized_types: + raise Exception( + f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}." + ) + assert "description" in input_content, f"Input '{input_name}' should have a description." + + assert getattr(self, "output_type", None) in authorized_types + + if not isinstance(self, PipelineTool): + signature = inspect.signature(self.forward) + if not set(signature.parameters.keys()) == set(self.inputs.keys()): + raise Exception( + "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." + ) def forward(self, *args, **kwargs): return NotImplemented("Write this method in your subclass of `Tool`.") @@ -817,24 +851,30 @@ def tool(tool_function: Callable) -> Tool: Decorator that turns a function into an instance of a specific Tool subclass Args: - tool_function: Your function. Should have type hints for each input and for the output, and a description with 'Args:' part where each argument is described. 
+ tool_function: Your function. Should have type hints for each input and a type hint for the output. + Should also have a docstring description including an 'Args:' part where each argument is described. """ parameters = get_json_schema(tool_function)["function"] + if "return" not in parameters: + raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!") class_name = f"{parameters['name'].capitalize()}Tool" - specific_tool_class = type( - class_name, - (Tool,), - { - "name": parameters["name"], - "description": parameters["description"], - "inputs": parameters["parameters"], - "output_type": parameters["return"]["type"], - }, - ) - def forward(self, *args, **kwargs): - return tool_function(*args, **kwargs) + class SpecificTool(Tool): + name = parameters["name"] + description = parameters["description"] + inputs = parameters["parameters"]["properties"] + output_type = parameters["return"]["type"] + + @wraps(tool_function) + def forward(self, *args, **kwargs): + return tool_function(*args, **kwargs) - setattr(specific_tool_class, "forward", forward) + original_signature = inspect.signature(tool_function) + new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list( + original_signature.parameters.values() + ) + new_signature = original_signature.replace(parameters=new_parameters) + SpecificTool.forward.__signature__ = new_signature - return specific_tool_class() + SpecificTool.__name__ = class_name + return SpecificTool() diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index d0fc6879b6b1..4f24abbeedd8 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -180,7 +180,6 @@ def test_fake_react_code_agent(self): assert isinstance(output, float) assert output == 7.2904 assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?" 
- assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6 assert agent.logs[2]["tool_call"] == { "tool_arguments": "final_answer(7.2904)", "tool_name": "code interpreter", @@ -233,7 +232,7 @@ def test_init_agent_with_different_toolsets(self): # check that python_interpreter base tool does not get added to code agents agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) - assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter) + assert len(agent.toolbox.tools) == 7 # added final_answer tool + 6 base tools (excluding interpreter) def test_function_persistence_across_steps(self): agent = ReactCodeAgent( diff --git a/tests/agents/test_tools_common.py b/tests/agents/test_tools_common.py index 47770669c3eb..8226e7109884 100644 --- a/tests/agents/test_tools_common.py +++ b/tests/agents/test_tools_common.py @@ -17,10 +17,11 @@ from typing import Dict, Union import numpy as np +import pytest from transformers import is_torch_available, is_vision_available from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText -from transformers.agents.tools import tool +from transformers.agents.tools import Tool, tool from transformers.testing_utils import get_tests_dir, is_agent_test @@ -112,8 +113,59 @@ def coolfunc(a: str, b: int) -> float: Args: a: The first argument - b: The third one + b: The second one """ return b + 2, a assert coolfunc.output_type == "number" + + def test_tool_init_vanilla(self): + class HFModelDownloadsTool(Tool): + name = "model_download_counter" + description = """ + This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. 
+ It returns the name of the checkpoint.""" + + inputs = { + "task": { + "type": "string", + "description": "the task category (such as text-classification, depth-estimation, etc)", + } + } + output_type = "integer" + + def forward(self, task): + return "best model" + + tool = HFModelDownloadsTool() + assert list(tool.inputs.keys())[0] == "task" + + def test_tool_init_decorator_raises_issues(self): + with pytest.raises(Exception) as e: + + @tool + def coolfunc(a: str, b: int): + """Cool function + + Args: + a: The first argument + b: The second one + """ + return a + b + + assert coolfunc.output_type == "number" + assert "Tool return type not found" in str(e) + + with pytest.raises(Exception) as e: + + @tool + def coolfunc(a: str, b: int) -> int: + """Cool function + + Args: + a: The first argument + """ + return b + a + + assert coolfunc.output_type == "number" + assert "docstring has no description for the argument" in str(e) From 1004967f2569b3859dae871a590185fa5817e240 Mon Sep 17 00:00:00 2001 From: Aymeric Date: Wed, 18 Sep 2024 10:01:36 +0200 Subject: [PATCH 6/6] Add tool to init files --- docs/source/en/agents.md | 2 ++ docs/source/en/main_classes/agent.md | 4 ++++ src/transformers/__init__.py | 2 ++ src/transformers/agents/__init__.py | 4 ++-- 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/source/en/agents.md b/docs/source/en/agents.md index 137ec9d646eb..ac06c04d9baa 100644 --- a/docs/source/en/agents.md +++ b/docs/source/en/agents.md @@ -329,6 +329,8 @@ This code can quickly be converted into a tool, just by wrapping it in a functio ```py +from transformers import tool + @tool def model_download_counter(task: str) -> str: """ diff --git a/docs/source/en/main_classes/agent.md b/docs/source/en/main_classes/agent.md index 8628785815ce..ed0486b60128 100644 --- a/docs/source/en/main_classes/agent.md +++ b/docs/source/en/main_classes/agent.md @@ -60,6 +60,10 @@ We provide two types of agents, based on the main [`Agent`] class: [[autodoc]] 
load_tool +### tool + +[[autodoc]] tool + ### Tool [[autodoc]] Tool diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 36775d8454ab..bfd0d37916b5 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -70,6 +70,7 @@ "launch_gradio_demo", "load_tool", "stream_to_gradio", + "tool", ], "audio_utils": [], "benchmark": [], @@ -4819,6 +4820,7 @@ launch_gradio_demo, load_tool, stream_to_gradio, + tool, ) from .configuration_utils import PretrainedConfig diff --git a/src/transformers/agents/__init__.py b/src/transformers/agents/__init__.py index d053e385cf7a..70762c252a83 100644 --- a/src/transformers/agents/__init__.py +++ b/src/transformers/agents/__init__.py @@ -27,7 +27,7 @@ "agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"], "llm_engine": ["HfApiEngine", "TransformersEngine"], "monitoring": ["stream_to_gradio"], - "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"], + "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"], } try: @@ -48,7 +48,7 @@ from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox from .llm_engine import HfApiEngine, TransformersEngine from .monitoring import stream_to_gradio - from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool + from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool try: if not is_torch_available():