From 6db07c973836d3768be5e7126c4ba0c2ebcc16d4 Mon Sep 17 00:00:00 2001 From: Aymeric Date: Fri, 1 Nov 2024 13:40:52 +0100 Subject: [PATCH 1/3] Agents: load a Space as a tool --- docs/source/en/agents_advanced.md | 38 +++++++++- src/transformers/agents/tools.py | 111 ++++++++++++++++++++++-------- 2 files changed, 117 insertions(+), 32 deletions(-) diff --git a/docs/source/en/agents_advanced.md b/docs/source/en/agents_advanced.md index ddcc619b4f91..20e27de2be7f 100644 --- a/docs/source/en/agents_advanced.md +++ b/docs/source/en/agents_advanced.md @@ -123,6 +123,40 @@ from transformers import load_tool, CodeAgent model_download_tool = load_tool("m-ric/hf-model-downloads") ``` +### Import a Space as a tool 🚀 + +You can directly import a Space from the Hub as a tool using the [`Tool.from_space`] method! + +You only need to provide the id of the Space on the Hub, its name, and a description that will help you agent understand what the tool does. Under the hood, this will use [`gradio-client`](https://pypi.org/project/gradio-client/) library to call the Space. + +For instance, let's import the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) Space from the Hub and use it to generate an image. + +``` +from transformers import Tool + +image_generation_tool = Tool.from_space( + "black-forest-labs/FLUX.1-dev", + name="image_generator", + description="Generate an image from a prompt") + +image_generation_tool("A sunny beach") +``` + +And voilà, here's your image! 🏖️ + + + +Then you can use this tool just like any other tool. For example, let's improve the prompt and generate an image of it. +```python +from transformers import ReactCodeAgent + +agent = ReactCodeAgent(tools=[image_generation_tool]) + +agent.run( + "Improve this prompt, then generate an image of it.", prompt='A rabbit wearing a space suit' +) +``` + ### Use gradio-tools [gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging @@ -179,7 +213,7 @@ We love Langchain and think it has a very compelling suite of tools. To import a tool from LangChain, use the `from_langchain()` method. Here is how you can use it to recreate the intro's search result using a LangChain web search tool. - +This tool will need `pip install google-search-results` to work properly. ```python from langchain.agents import load_tools from transformers import Tool, ReactCodeAgent @@ -188,7 +222,7 @@ search_tool = Tool.from_langchain(load_tools(["serpapi"])[0]) agent = ReactCodeAgent(tools=[search_tool]) -agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?") +agent.run("How many more blocks (also denoted as layers) are in BERT base encoder compared to the encoder from the architecture proposed in Attention is All You Need?") ``` ## Display your agent run in a cool Gradio interface diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index a425ffc8f106..51f7c0a4bc68 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -87,20 +87,20 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs): """ -def validate_after_init(cls): +def validate_after_init(cls, do_validate_forward: bool = True): original_init = cls.__init__ @wraps(original_init) def new_init(self, *args, **kwargs): original_init(self, *args, **kwargs) if not isinstance(self, PipelineTool): - self.validate_arguments() + self.validate_arguments(do_validate_forward=do_validate_forward) cls.__init__ = new_init return cls +CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"} -@validate_after_init class Tool: """ A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the @@ -131,7 +131,12 @@ class Tool: def __init__(self, *args, **kwargs): self.is_initialized = False - def validate_arguments(self): + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + validate_after_init(cls, do_validate_forward=False) + + + def validate_arguments(self, do_validate_forward: bool = True): required_attributes = { "description": str, "name": str, @@ -145,21 +150,21 @@ def validate_arguments(self): if not isinstance(attr_value, expected_type): raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.") for input_name, input_content in self.inputs.items(): - assert "type" in input_content, f"Input '{input_name}' should specify a type." + assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary." + assert "type" in input_content and "description" in input_content, f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." if input_content["type"] not in authorized_types: raise Exception( f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}." ) - assert "description" in input_content, f"Input '{input_name}' should have a description." assert getattr(self, "output_type", None) in authorized_types - - if not isinstance(self, PipelineTool): - signature = inspect.signature(self.forward) - if not set(signature.parameters.keys()) == set(self.inputs.keys()): - raise Exception( - "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." - ) + if do_validate_forward: + if not isinstance(self, PipelineTool): + signature = inspect.signature(self.forward) + if not set(signature.parameters.keys()) == set(self.inputs.keys()): + raise Exception( + "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." + ) def forward(self, *args, **kwargs): return NotImplemented("Write this method in your subclass of `Tool`.") @@ -404,6 +409,61 @@ def push_to_hub( create_pr=create_pr, repo_type="space", ) + + @staticmethod + def from_space(space_id, name, description): + """ + Creates a [`Tool`] from a Space given its id on the Hub. + + Args: + space_id (`str`): + The id of the Space on the Hub. + name (`str`): + The name of the tool. + description (`str`): + The description of the tool. + + Returns: + [`Tool`]: + The created tool. + + Example: + ``` + tool = Tool.from_space("black-forest-labs/FLUX.1-schnell", "image-generator", "Generate an image from a prompt") + ``` + """ + from gradio_client import Client + + class SpaceToolWrapper(Tool): + def __init__(self, space_id, name, description): + self.client = Client(space_id) + self.name = name + self.description = description + space_description = self.client.view_api(return_format="dict")[ + "named_endpoints" + ] + route = list(space_description.keys())[0] + space_description_route = space_description[route] + self.inputs = {} + for parameter in space_description_route["parameters"]: + if not parameter["parameter_has_default"]: + self.inputs[parameter["parameter_name"]] = { + "type": parameter["type"]["type"], + "description": parameter["python_type"]["description"], + } + output_component = space_description_route["returns"][0]["component"] + if output_component == "Image": + self.output_type = "image" + elif output_component == "Audio": + self.output_type = "audio" + else: + self.output_type = "any" + + def forward(self, *args, **kwargs): + return self.client.predict(*args, **kwargs)[0] # Usually the first output is the result + + return SpaceToolWrapper(space_id, name, description) + @staticmethod def from_gradio(gradio_tool): @@ -414,16 +474,13 @@ def from_gradio(gradio_tool): class GradioToolWrapper(Tool): def __init__(self, _gradio_tool): - super().__init__() self.name = _gradio_tool.name self.description = _gradio_tool.description self.output_type = "string" self._gradio_tool = _gradio_tool - func_args = list(inspect.signature(_gradio_tool.run).parameters.keys()) - self.inputs = {key: "" for key in func_args} - - def forward(self, *args, **kwargs): - return self._gradio_tool.run(*args, **kwargs) + func_args = list(inspect.signature(_gradio_tool.run).parameters.items()) + self.inputs = {key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args} + self.forward = self._gradio_tool.run return GradioToolWrapper(gradio_tool) @@ -435,10 +492,13 @@ def from_langchain(langchain_tool): class LangChainToolWrapper(Tool): def __init__(self, _langchain_tool): - super().__init__() self.name = _langchain_tool.name.lower() self.description = _langchain_tool.description - self.inputs = parse_langchain_args(_langchain_tool.args) + self.inputs = _langchain_tool.args.copy() + for input_content in self.inputs.values(): + if "title" in input_content: + input_content.pop("title") + input_content["description"] = "" self.output_type = "string" self.langchain_tool = _langchain_tool @@ -805,15 +865,6 @@ def __call__( return response.json() -def parse_langchain_args(args: Dict[str, str]) -> Dict[str, str]: - """Parse the args attribute of a LangChain tool to create a matching inputs dictionary.""" - inputs = args.copy() - for arg_details in inputs.values(): - if "title" in arg_details: - arg_details.pop("title") - return inputs - - class ToolCollection: """ Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox. From c5f36f4b6b7d1485ce3c92199475bc454e4efd6a Mon Sep 17 00:00:00 2001 From: Aymeric Date: Fri, 1 Nov 2024 13:58:08 +0100 Subject: [PATCH 2/3] Use cooler rabbit image --- docs/source/ar/agents.md | 2 +- docs/source/en/agents_advanced.md | 48 +++++++++++-------------------- 2 files changed, 17 insertions(+), 33 deletions(-) diff --git a/docs/source/ar/agents.md b/docs/source/ar/agents.md index 92b2a4715f6f..1213b3500860 100644 --- a/docs/source/ar/agents.md +++ b/docs/source/ar/agents.md @@ -464,7 +464,7 @@ image = image_generator(prompt=improved_prompt) قبل إنشاء الصورة أخيرًا: - + > [!WARNING] > تتطلب gradio-tools إدخالات وإخراجات *نصية* حتى عند العمل مع طرائق مختلفة مثل كائنات الصور والصوت. الإدخالات والإخراجات الصورية والصوتية غير متوافقة حاليًا. diff --git a/docs/source/en/agents_advanced.md b/docs/source/en/agents_advanced.md index 20e27de2be7f..e80e402d7374 100644 --- a/docs/source/en/agents_advanced.md +++ b/docs/source/en/agents_advanced.md @@ -141,12 +141,12 @@ image_generation_tool = Tool.from_space( image_generation_tool("A sunny beach") ``` - And voilà, here's your image! 🏖️ -Then you can use this tool just like any other tool. For example, let's improve the prompt and generate an image of it. +Then you can use this tool just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit` and generate an image of it. + ```python from transformers import ReactCodeAgent @@ -157,6 +157,20 @@ agent.run( ) ``` +```text +=== Agent thoughts: +improved_prompt could be "A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background" + +Now that I have improved the prompt, I can use the image generator tool to generate an image based on this prompt. +>>> Agent is executing the code below: +image = image_generator(prompt="A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background") +final_answer(image) +``` + + + +How cool is this? 🤩 + ### Use gradio-tools [gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging @@ -174,36 +188,6 @@ gradio_prompt_generator_tool = StableDiffusionPromptGeneratorTool() prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool) ``` -Now you can use it just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit`. - -```python -image_generation_tool = load_tool('huggingface-tools/text-to-image') -agent = CodeAgent(tools=[prompt_generator_tool, image_generation_tool], llm_engine=llm_engine) - -agent.run( - "Improve this prompt, then generate an image of it.", prompt='A rabbit wearing a space suit' -) -``` - -The model adequately leverages the tool: -```text -======== New task ======== -Improve this prompt, then generate an image of it. -You have been provided with these initial arguments: {'prompt': 'A rabbit wearing a space suit'}. -==== Agent is executing the code below: -improved_prompt = StableDiffusionPromptGenerator(query=prompt) -while improved_prompt == "QUEUE_FULL": - improved_prompt = StableDiffusionPromptGenerator(query=prompt) -print(f"The improved prompt is {improved_prompt}.") -image = image_generator(prompt=improved_prompt) -==== -``` - -Before finally generating the image: - - - - > [!WARNING] > gradio-tools require *textual* inputs and outputs even when working with different modalities like image and audio objects. Image and audio inputs and outputs are currently incompatible. From e9ea3ba4a598aea40e899f81cf27ed66b6c8719d Mon Sep 17 00:00:00 2001 From: Aymeric Date: Mon, 4 Nov 2024 17:33:23 +0100 Subject: [PATCH 3/3] Fix formatting --- src/transformers/agents/tools.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index 51f7c0a4bc68..994e1bdd817b 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -99,8 +99,10 @@ def new_init(self, *args, **kwargs): cls.__init__ = new_init return cls + CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"} + class Tool: """ A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the @@ -135,7 +137,6 @@ def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) validate_after_init(cls, do_validate_forward=False) - def validate_arguments(self, do_validate_forward: bool = True): required_attributes = { "description": str, @@ -151,7 +152,9 @@ def validate_arguments(self, do_validate_forward: bool = True): raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.") for input_name, input_content in self.inputs.items(): assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary." - assert "type" in input_content and "description" in input_content, f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." + assert ( + "type" in input_content and "description" in input_content + ), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." if input_content["type"] not in authorized_types: raise Exception( f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}." @@ -409,7 +412,7 @@ def push_to_hub( create_pr=create_pr, repo_type="space", ) - + @staticmethod def from_space(space_id, name, description): """ @@ -439,9 +442,7 @@ def __init__(self, space_id, name, description): self.client = Client(space_id) self.name = name self.description = description - space_description = self.client.view_api(return_format="dict")[ - "named_endpoints" - ] + space_description = self.client.view_api(return_format="dict")["named_endpoints"] route = list(space_description.keys())[0] space_description_route = space_description[route] self.inputs = {} @@ -460,10 +461,9 @@ def __init__(self, space_id, name, description): self.output_type = "any" def forward(self, *args, **kwargs): - return self.client.predict(*args, **kwargs)[0] # Usually the first output is the result - - return SpaceToolWrapper(space_id, name, description) + return self.client.predict(*args, **kwargs)[0] # Usually the first output is the result + return SpaceToolWrapper(space_id, name, description) @staticmethod def from_gradio(gradio_tool): @@ -479,7 +479,9 @@ def __init__(self, _gradio_tool): self.output_type = "string" self._gradio_tool = _gradio_tool func_args = list(inspect.signature(_gradio_tool.run).parameters.items()) - self.inputs = {key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args} + self.inputs = { + key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args + } self.forward = self._gradio_tool.run return GradioToolWrapper(gradio_tool)