diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index 446aa4f2508e..73b7186d25a3 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -22,7 +22,7 @@ from .. import is_torch_available from ..utils import logging as transformers_logging from ..utils.import_utils import is_pygments_available -from .agent_types import AgentAudio, AgentImage, AgentText +from .agent_types import AgentAudio, AgentImage from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools from .llm_engine import HfApiEngine, MessageRole from .prompts import ( diff --git a/src/transformers/agents/default_tools.py b/src/transformers/agents/default_tools.py index 10097fa86743..3946aa9f8735 100644 --- a/src/transformers/agents/default_tools.py +++ b/src/transformers/agents/default_tools.py @@ -25,7 +25,7 @@ from ..utils import is_offline_mode from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code -from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool, tool +from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool def custom_print(*args): diff --git a/src/transformers/agents/prompts.py b/src/transformers/agents/prompts.py index 266d0d78796f..7a84b1db44fa 100644 --- a/src/transformers/agents/prompts.py +++ b/src/transformers/agents/prompts.py @@ -199,7 +199,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): Action: { "action": "image_generator", - "action_input": {"prompt": ""A portrait of John Doe, a 55-year-old man living in Canada.""} + "action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."} } Observation: "image.png" diff --git a/src/transformers/agents/search.py b/src/transformers/agents/search.py index 4889de084711..f50a7c6ab8f9 100644 --- a/src/transformers/agents/search.py +++ b/src/transformers/agents/search.py @@ -19,7 +19,7 @@ import requests from requests.exceptions import RequestException -from .tools import Tool, tool +from .tools import Tool class DuckDuckGoSearchTool(Tool): @@ -50,7 +50,7 @@ class VisitWebpageTool(Tool): } } output_type = "string" - + def forward(self, url: str) -> str: try: from markdownify import markdownify diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 1ca8e59d926d..74912ce30146 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -22,7 +22,7 @@ from packaging import version -from .import_utils import is_jinja_available, is_vision_available, is_torch_available +from .import_utils import is_jinja_available, is_torch_available, is_vision_available if is_jinja_available(): @@ -69,6 +69,7 @@ class DocstringParsingException(Exception): pass + def _get_json_schema_type(param_type: str) -> Dict[str, str]: type_mapping = { int: {"type": "integer"}, diff --git a/tests/agents/test_final_answer.py b/tests/agents/test_final_answer.py index 3bd7c0f124fc..ee744bbd7bd3 100644 --- a/tests/agents/test_final_answer.py +++ b/tests/agents/test_final_answer.py @@ -19,11 +19,12 @@ import numpy as np from PIL import Image -from transformers import is_torch_available, load_tool +from transformers import is_torch_available from transformers.agents.agent_types import AGENT_TYPE_MAPPING from transformers.agents.default_tools import FinalAnswerTool from transformers.testing_utils import get_tests_dir, require_torch + if is_torch_available(): import torch diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index 71dcbe3b641c..15e5ad7bb3a3 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -392,7 +392,7 @@ def test_if_conditions(self): if char.isalpha(): print('2')""" state = {} - result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) + evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) assert state["print_outputs"] == "2\n" def test_imports(self): @@ -470,7 +470,7 @@ def test_print_output(self): code = "print('Hello world!')\nprint('Ok no one cares')" state = {} result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) - assert result == None + assert result is None assert state["print_outputs"] == "Hello world!\nOk no one cares\n" # test print in function @@ -594,7 +594,7 @@ def method_that_raises(self): def test_print(self): code = "print(min([1, 2, 3]))" state = {} - result = evaluate_python_code(code, {"min": min, "print": print}, state=state) + evaluate_python_code(code, {"min": min, "print": print}, state=state) assert state["print_outputs"] == "1\n" def test_types_as_objects(self): diff --git a/tests/agents/test_tools_common.py b/tests/agents/test_tools_common.py index 521783a7b6c9..47770669c3eb 100644 --- a/tests/agents/test_tools_common.py +++ b/tests/agents/test_tools_common.py @@ -31,7 +31,7 @@ from PIL import Image -AUTHORIZED_TYPES = ["string", "number", "audio", "image", "any"] +AUTHORIZED_TYPES = ["string", "boolean", "integer", "number", "audio", "image", "any"] def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):