Skip to content

Commit

Permalink
feat(sdk): make ClientSession optional when initializing ToolboxClient (
Browse files Browse the repository at this point in the history
  • Loading branch information
anubhav756 authored Nov 12, 2024
1 parent a415f29 commit 26347b5
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 18 deletions.
8 changes: 8 additions & 0 deletions sdks/langchain/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ from toolbox_langchain_sdk import ToolboxClient
toolbox = ToolboxClient("http://localhost:5000")
```

> [!TIP]
> You can also pass your own `ClientSession` so that the `ToolboxClient` can
> reuse the same session.
> ```
> async with ClientSession() as session:
> client = ToolboxClient(http://localhost:5000, session)
> ```
## Load a toolset
You can load a toolset, a collection of related tools.
Expand Down
43 changes: 37 additions & 6 deletions sdks/langchain/src/toolbox_langchain_sdk/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Optional

from aiohttp import ClientSession
Expand All @@ -8,20 +9,47 @@


class ToolboxClient:
def __init__(self, url: str, session: ClientSession):
def __init__(self, url: str, session: Optional[ClientSession] = None):
"""
Initializes the ToolboxClient for the Toolbox service at the given URL.
Args:
url: The base URL of the Toolbox service.
session: The HTTP client session.
Default: None
"""
self._url: str = url
self._session = session
self._should_close_session: bool = session != None
self._session: ClientSession = session or ClientSession()

async def close(self) -> None:
"""
Close the Toolbox client and its tools.
"""
# We check whether _should_close_session is set or not since we do not
# want to close the session in case the user had passed their own
# ClientSession object, since then we expect the user to be owning its
# lifecycle.
if self._session and self._should_close_session:
await self._session.close()

def __del__(self):
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(self.close())
else:
loop.run_until_complete(self.close())
except Exception:
# We "pass" assuming that the exception is thrown because the event
# loop is no longer running, but at that point the Session should
# have been closed already anyway.
pass

async def _load_tool_manifest(self, tool_name: str) -> ManifestSchema:
"""
Fetches and parses the YAML manifest for the given tool from the Toolbox service.
Fetches and parses the YAML manifest for the given tool from the Toolbox
service.
Args:
tool_name: The name of the tool to load.
Expand All @@ -40,7 +68,8 @@ async def _load_toolset_manifest(
Args:
toolset_name: The name of the toolset to load.
Default: None. If not provided, then all the available tools are loaded.
Default: None. If not provided, then all the available tools are
loaded.
Returns:
The parsed Toolbox manifest.
Expand All @@ -52,7 +81,8 @@ def _generate_tool(
self, tool_name: str, manifest: ManifestSchema
) -> StructuredTool:
"""
Creates a StructuredTool object and a dynamically generated BaseModel for the given tool.
Creates a StructuredTool object and a dynamically generated BaseModel
for the given tool.
Args:
tool_name: The name of the tool to generate.
Expand Down Expand Up @@ -93,7 +123,8 @@ async def load_toolset(
self, toolset_name: Optional[str] = None
) -> list[StructuredTool]:
"""
Loads tools from the Toolbox service, optionally filtered by toolset name.
Loads tools from the Toolbox service, optionally filtered by toolset
name.
Args:
toolset_name: The name of the toolset to load.
Expand Down
2 changes: 1 addition & 1 deletion sdks/langchain/src/toolbox_langchain_sdk/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Type, Optional
from typing import Any, Optional, Type

import yaml
from aiohttp import ClientSession
Expand Down
3 changes: 2 additions & 1 deletion sdks/llamaindex/src/toolbox_llamaindex_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .client import ToolboxClient

# import utils

__all__ = ["ToolboxClient"]
__all__ = ["ToolboxClient"]
46 changes: 39 additions & 7 deletions sdks/llamaindex/src/toolbox_llamaindex_sdk/client.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,55 @@
import asyncio
from typing import Optional

from aiohttp import ClientSession
from llama_index.core.tools import FunctionTool
from pydantic import BaseModel

from .utils import ManifestSchema, _invoke_tool, _load_yaml, _schema_to_model


class ToolboxClient:
def __init__(self, url: str, session: ClientSession):
def __init__(self, url: str, session: Optional[ClientSession] = None):
"""
Initializes the ToolboxClient for the Toolbox service at the given URL.
Args:
url: The base URL of the Toolbox service.
session: The HTTP client session.
Default: None
"""
self._url: str = url
self._session = session
self._should_close_session: bool = session != None
self._session: ClientSession = session or ClientSession()

async def close(self) -> None:
"""
Close the Toolbox client and its tools.
"""
# We check whether _should_close_session is set or not since we do not
# want to close the session in case the user had passed their own
# ClientSession object, since then we expect the user to be owning its
# lifecycle.
if self._session and self._should_close_session:
await self._session.close()

def __del__(self):
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(self.close())
else:
loop.run_until_complete(self.close())
except Exception:
# We "pass" assuming that the exception is thrown because the event
# loop is no longer running, but at that point the Session should
# have been closed already anyway.
pass

async def _load_tool_manifest(self, tool_name: str) -> ManifestSchema:
"""
Fetches and parses the YAML manifest for the given tool from the Toolbox service.
Fetches and parses the YAML manifest for the given tool from the Toolbox
service.
Args:
tool_name: The name of the tool to load.
Expand All @@ -39,7 +68,8 @@ async def _load_toolset_manifest(
Args:
toolset_name: The name of the toolset to load.
Default: None. If not provided, then all the available tools are loaded.
Default: None. If not provided, then all the available tools are
loaded.
Returns:
The parsed Toolbox manifest.
Expand All @@ -49,7 +79,8 @@ async def _load_toolset_manifest(

def _generate_tool(self, tool_name: str, manifest: ManifestSchema) -> FunctionTool:
"""
Creates a FunctionTool object and a dynamically generated BaseModel for the given tool.
Creates a FunctionTool object and a dynamically generated BaseModel for
the given tool.
Args:
tool_name: The name of the tool to generate.
Expand All @@ -59,7 +90,7 @@ def _generate_tool(self, tool_name: str, manifest: ManifestSchema) -> FunctionTo
The generated tool.
"""
tool_schema = manifest.tools[tool_name]
tool_model = _schema_to_model(
tool_model: BaseModel = _schema_to_model(
model_name=tool_name, schema=tool_schema.parameters
)

Expand Down Expand Up @@ -90,7 +121,8 @@ async def load_toolset(
self, toolset_name: Optional[str] = None
) -> list[FunctionTool]:
"""
Loads tools from the Toolbox service, optionally filtered by toolset name.
Loads tools from the Toolbox service, optionally filtered by toolset
name.
Args:
toolset_name: The name of the toolset to load.
Expand Down
5 changes: 2 additions & 3 deletions sdks/llamaindex/src/toolbox_llamaindex_sdk/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Type, Optional
from typing import Any, Optional, Type

import yaml
from aiohttp import ClientSession
Expand Down Expand Up @@ -103,8 +103,7 @@ async def _invoke_tool(
url = f"{url}/api/tool/{tool_name}/invoke"
async with session.post(url, json=_convert_none_to_empty_string(data)) as response:
response.raise_for_status()
json_response = await response.json()
return json_response
return await response.json()


# TODO: Remove this temporary fix once optional fields are supported by Toolbox.
Expand Down

0 comments on commit 26347b5

Please sign in to comment.