diff --git a/sdks/langchain/README.md b/sdks/langchain/README.md index af86d0c06..09a8df5b2 100644 --- a/sdks/langchain/README.md +++ b/sdks/langchain/README.md @@ -40,6 +40,14 @@ from toolbox_langchain_sdk import ToolboxClient toolbox = ToolboxClient("http://localhost:5000") ``` +> [!TIP] +> You can also pass your own `ClientSession` so that the `ToolboxClient` can +> reuse the same session. +> ``` +> async with ClientSession() as session: +> client = ToolboxClient(http://localhost:5000, session) +> ``` + ## Load a toolset You can load a toolset, a collection of related tools. diff --git a/sdks/langchain/src/toolbox_langchain_sdk/client.py b/sdks/langchain/src/toolbox_langchain_sdk/client.py index f3d12f58d..9dd4eae91 100644 --- a/sdks/langchain/src/toolbox_langchain_sdk/client.py +++ b/sdks/langchain/src/toolbox_langchain_sdk/client.py @@ -1,3 +1,4 @@ +import asyncio from typing import Optional from aiohttp import ClientSession @@ -8,20 +9,47 @@ class ToolboxClient: - def __init__(self, url: str, session: ClientSession): + def __init__(self, url: str, session: Optional[ClientSession] = None): """ Initializes the ToolboxClient for the Toolbox service at the given URL. Args: url: The base URL of the Toolbox service. session: The HTTP client session. + Default: None """ self._url: str = url - self._session = session + self._should_close_session: bool = session != None + self._session: ClientSession = session or ClientSession() + + async def close(self) -> None: + """ + Close the Toolbox client and its tools. + """ + # We check whether _should_close_session is set or not since we do not + # want to close the session in case the user had passed their own + # ClientSession object, since then we expect the user to be owning its + # lifecycle. + if self._session and self._should_close_session: + await self._session.close() + + def __del__(self): + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(self.close()) + else: + loop.run_until_complete(self.close()) + except Exception: + # We "pass" assuming that the exception is thrown because the event + # loop is no longer running, but at that point the Session should + # have been closed already anyway. + pass async def _load_tool_manifest(self, tool_name: str) -> ManifestSchema: """ - Fetches and parses the YAML manifest for the given tool from the Toolbox service. + Fetches and parses the YAML manifest for the given tool from the Toolbox + service. Args: tool_name: The name of the tool to load. @@ -40,7 +68,8 @@ async def _load_toolset_manifest( Args: toolset_name: The name of the toolset to load. - Default: None. If not provided, then all the available tools are loaded. + Default: None. If not provided, then all the available tools are + loaded. Returns: The parsed Toolbox manifest. @@ -52,7 +81,8 @@ def _generate_tool( self, tool_name: str, manifest: ManifestSchema ) -> StructuredTool: """ - Creates a StructuredTool object and a dynamically generated BaseModel for the given tool. + Creates a StructuredTool object and a dynamically generated BaseModel + for the given tool. Args: tool_name: The name of the tool to generate. @@ -93,7 +123,8 @@ async def load_toolset( self, toolset_name: Optional[str] = None ) -> list[StructuredTool]: """ - Loads tools from the Toolbox service, optionally filtered by toolset name. + Loads tools from the Toolbox service, optionally filtered by toolset + name. Args: toolset_name: The name of the toolset to load. diff --git a/sdks/langchain/src/toolbox_langchain_sdk/utils.py b/sdks/langchain/src/toolbox_langchain_sdk/utils.py index cdaf88b5e..6d9d72615 100644 --- a/sdks/langchain/src/toolbox_langchain_sdk/utils.py +++ b/sdks/langchain/src/toolbox_langchain_sdk/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Type, Optional +from typing import Any, Optional, Type import yaml from aiohttp import ClientSession diff --git a/sdks/llamaindex/src/toolbox_llamaindex_sdk/__init__.py b/sdks/llamaindex/src/toolbox_llamaindex_sdk/__init__.py index 457cdc020..5bfbc3115 100644 --- a/sdks/llamaindex/src/toolbox_llamaindex_sdk/__init__.py +++ b/sdks/llamaindex/src/toolbox_llamaindex_sdk/__init__.py @@ -1,4 +1,5 @@ from .client import ToolboxClient + # import utils -__all__ = ["ToolboxClient"] \ No newline at end of file +__all__ = ["ToolboxClient"] diff --git a/sdks/llamaindex/src/toolbox_llamaindex_sdk/client.py b/sdks/llamaindex/src/toolbox_llamaindex_sdk/client.py index 012c83327..c153e6f26 100644 --- a/sdks/llamaindex/src/toolbox_llamaindex_sdk/client.py +++ b/sdks/llamaindex/src/toolbox_llamaindex_sdk/client.py @@ -1,26 +1,55 @@ +import asyncio from typing import Optional from aiohttp import ClientSession from llama_index.core.tools import FunctionTool +from pydantic import BaseModel from .utils import ManifestSchema, _invoke_tool, _load_yaml, _schema_to_model class ToolboxClient: - def __init__(self, url: str, session: ClientSession): + def __init__(self, url: str, session: Optional[ClientSession] = None): """ Initializes the ToolboxClient for the Toolbox service at the given URL. Args: url: The base URL of the Toolbox service. session: The HTTP client session. + Default: None """ self._url: str = url - self._session = session + self._should_close_session: bool = session != None + self._session: ClientSession = session or ClientSession() + + async def close(self) -> None: + """ + Close the Toolbox client and its tools. + """ + # We check whether _should_close_session is set or not since we do not + # want to close the session in case the user had passed their own + # ClientSession object, since then we expect the user to be owning its + # lifecycle. + if self._session and self._should_close_session: + await self._session.close() + + def __del__(self): + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(self.close()) + else: + loop.run_until_complete(self.close()) + except Exception: + # We "pass" assuming that the exception is thrown because the event + # loop is no longer running, but at that point the Session should + # have been closed already anyway. + pass async def _load_tool_manifest(self, tool_name: str) -> ManifestSchema: """ - Fetches and parses the YAML manifest for the given tool from the Toolbox service. + Fetches and parses the YAML manifest for the given tool from the Toolbox + service. Args: tool_name: The name of the tool to load. @@ -39,7 +68,8 @@ async def _load_toolset_manifest( Args: toolset_name: The name of the toolset to load. - Default: None. If not provided, then all the available tools are loaded. + Default: None. If not provided, then all the available tools are + loaded. Returns: The parsed Toolbox manifest. @@ -49,7 +79,8 @@ async def _load_toolset_manifest( def _generate_tool(self, tool_name: str, manifest: ManifestSchema) -> FunctionTool: """ - Creates a FunctionTool object and a dynamically generated BaseModel for the given tool. + Creates a FunctionTool object and a dynamically generated BaseModel for + the given tool. Args: tool_name: The name of the tool to generate. @@ -59,7 +90,7 @@ def _generate_tool(self, tool_name: str, manifest: ManifestSchema) -> FunctionTo The generated tool. """ tool_schema = manifest.tools[tool_name] - tool_model = _schema_to_model( + tool_model: BaseModel = _schema_to_model( model_name=tool_name, schema=tool_schema.parameters ) @@ -90,7 +121,8 @@ async def load_toolset( self, toolset_name: Optional[str] = None ) -> list[FunctionTool]: """ - Loads tools from the Toolbox service, optionally filtered by toolset name. + Loads tools from the Toolbox service, optionally filtered by toolset + name. Args: toolset_name: The name of the toolset to load. diff --git a/sdks/llamaindex/src/toolbox_llamaindex_sdk/utils.py b/sdks/llamaindex/src/toolbox_llamaindex_sdk/utils.py index 1dda638da..6d9d72615 100644 --- a/sdks/llamaindex/src/toolbox_llamaindex_sdk/utils.py +++ b/sdks/llamaindex/src/toolbox_llamaindex_sdk/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Type, Optional +from typing import Any, Optional, Type import yaml from aiohttp import ClientSession @@ -103,8 +103,7 @@ async def _invoke_tool( url = f"{url}/api/tool/{tool_name}/invoke" async with session.post(url, json=_convert_none_to_empty_string(data)) as response: response.raise_for_status() - json_response = await response.json() - return json_response + return await response.json() # TODO: Remove this temporary fix once optional fields are supported by Toolbox.