Skip to content

Commit

Permalink
fix bug with openai tool calls (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohntheLi authored Jan 6, 2025
1 parent 1eb5e67 commit f008a93
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 87 deletions.
8 changes: 1 addition & 7 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest", "macos-latest", "windows-latest"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
exclude: # Apple Silicon ARM64 does not support Python < v3.8
- python-version: "3.7"
os: macos-latest
include: # So run those legacy versions on Intel CPUs
- python-version: "3.7"
os: macos-13
python-version: ["3.9", "3.10", "3.11", "3.12"]
fail-fast: false
steps:
- uses: actions/checkout@v3
Expand Down
12 changes: 6 additions & 6 deletions docs/generate_api_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import sys
import types
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Optional, Union

sys.path.append("../src")
import fastapi_poe
Expand All @@ -32,7 +32,7 @@ class DocumentationData:
name: str
docstring: Optional[str]
data_type: str
children: List = field(default_factory=lambda: [])
children: list = field(default_factory=lambda: [])


def _unwrap_func(func_obj: Union[staticmethod, Callable]) -> Callable:
Expand All @@ -43,8 +43,8 @@ def _unwrap_func(func_obj: Union[staticmethod, Callable]) -> Callable:


def get_documentation_data(
*, module: types.ModuleType, documented_items: List[str]
) -> Dict[str, DocumentationData]:
*, module: types.ModuleType, documented_items: list[str]
) -> dict[str, DocumentationData]:
data_dict = {}
for name, obj in inspect.getmembers(module):
if (
Expand Down Expand Up @@ -75,8 +75,8 @@ def get_documentation_data(

def generate_documentation(
*,
data_dict: Dict[str, DocumentationData],
documented_items: List[str],
data_dict: dict[str, DocumentationData],
documented_items: list[str],
output_filename: str,
) -> None:
# reset the file first
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ authors = [
]
description = "A demonstration of the Poe protocol using FastAPI"
readme = "README.md"
requires-python = ">=3.7"
requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
Expand All @@ -34,10 +34,10 @@ dependencies = [
"Homepage" = "https://creator.poe.com/"

[tool.pyright]
pythonVersion = "3.7"
pythonVersion = "3.9"

[tool.black]
target-version = ['py37']
target-version = ['py39']
skip-magic-trailing-comma = true

[tool.ruff]
Expand Down Expand Up @@ -78,4 +78,4 @@ lint.ignore = [
]

line-length = 100
target-version = "py37"
target-version = "py39"
37 changes: 16 additions & 21 deletions src/fastapi_poe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,18 @@
import sys
import warnings
from collections import defaultdict
from collections.abc import AsyncIterable, Awaitable, Sequence
from dataclasses import dataclass
from typing import (
AsyncIterable,
Awaitable,
BinaryIO,
Callable,
Dict,
List,
Optional,
Sequence,
Union,
)
from typing import BinaryIO, Callable, Optional, Union

import httpx
import httpx_sse
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from sse_starlette.event import ServerSentEvent
from sse_starlette.sse import EventSourceResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import Message
from typing_extensions import deprecated, overload
Expand Down Expand Up @@ -655,7 +647,7 @@ def make_prompt_author_role_alternated(
async def capture_cost(
self,
request: QueryRequest,
amounts: Union[List[CostItem], CostItem],
amounts: Union[list[CostItem], CostItem],
base_url: str = "https://api.poe.com/",
) -> None:
"""
Expand All @@ -665,7 +657,7 @@ async def capture_cost(
#### Parameters:
- `request` (`QueryRequest`): The currently handlded QueryRequest object.
- `amounts` (`Union[List[CostItem], CostItem]`): The to be captured amounts.
- `amounts` (`Union[list[CostItem], CostItem]`): The to be captured amounts.
"""

Expand All @@ -689,7 +681,7 @@ async def capture_cost(
async def authorize_cost(
self,
request: QueryRequest,
amounts: Union[List[CostItem], CostItem],
amounts: Union[list[CostItem], CostItem],
base_url: str = "https://api.poe.com/",
) -> None:
"""
Expand All @@ -699,7 +691,7 @@ async def authorize_cost(
#### Parameters:
- `request` (`QueryRequest`): The currently handlded QueryRequest object.
- `amounts` (`Union[List[CostItem], CostItem]`): The to be authorized amounts.
- `amounts` (`Union[list[CostItem], CostItem]`): The to be authorized amounts.
"""

Expand All @@ -721,15 +713,18 @@ async def authorize_cost(
raise InsufficientFundError()

async def _cost_requests_inner(
self, amounts: Union[List[CostItem], CostItem], access_key: str, url: str
self, amounts: Union[list[CostItem], CostItem], access_key: str, url: str
) -> bool:
amounts = [amounts] if isinstance(amounts, CostItem) else amounts
amounts_dicts = [amount.model_dump() for amount in amounts]
data = {"amounts": amounts_dicts, "access_key": access_key}
try:
async with httpx.AsyncClient(timeout=300) as client, httpx_sse.aconnect_sse(
client, method="POST", url=url, json=data
) as event_source:
async with (
httpx.AsyncClient(timeout=300) as client,
httpx_sse.aconnect_sse(
client, method="POST", url=url, json=data
) as event_source,
):
if event_source.response.status_code != 200:
error_pieces = [
json.loads(event.data).get("message", "")
Expand Down Expand Up @@ -798,7 +793,7 @@ def error_event(
allow_retry: bool = True,
error_type: Optional[str] = None,
) -> ServerSentEvent:
data: Dict[str, Union[bool, str]] = {"allow_retry": allow_retry}
data: dict[str, Union[bool, str]] = {"allow_retry": allow_retry}
if text is not None:
data["text"] = text
if raw_response is not None:
Expand Down
65 changes: 35 additions & 30 deletions src/fastapi_poe/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import inspect
import json
import warnings
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, cast
from typing import Any, Callable, Optional, cast

import httpx
import httpx_sse
Expand Down Expand Up @@ -66,14 +67,14 @@ class _BotContext:
on_error: Optional[ErrorHandler] = field(default=None, repr=False)

@property
def headers(self) -> Dict[str, str]:
def headers(self) -> dict[str, str]:
headers = {"Accept": "application/json"}
if self.api_key is not None:
headers["Authorization"] = f"Bearer {self.api_key}"
return headers

async def report_error(
self, message: str, metadata: Optional[Dict[str, Any]] = None
self, message: str, metadata: Optional[dict[str, Any]] = None
) -> None:
"""Report an error to the bot server."""
if self.on_error is not None:
Expand Down Expand Up @@ -148,11 +149,11 @@ async def perform_query_request(
self,
*,
request: QueryRequest,
tools: Optional[List[ToolDefinition]],
tool_calls: Optional[List[ToolCallDefinition]],
tool_results: Optional[List[ToolResultDefinition]],
tools: Optional[list[ToolDefinition]],
tool_calls: Optional[list[ToolCallDefinition]],
tool_results: Optional[list[ToolResultDefinition]],
) -> AsyncGenerator[BotMessage, None]:
chunks: List[str] = []
chunks: list[str] = []
message_id = request.message_id
event_count = 0
error_reported = False
Expand Down Expand Up @@ -291,7 +292,7 @@ async def _get_single_json_field(

async def _load_json_dict(
self, data: str, context: str, message_id: Identifier
) -> Dict[str, object]:
) -> dict[str, object]:
try:
parsed = json.loads(data)
except json.JSONDecodeError:
Expand All @@ -307,7 +308,7 @@ async def _load_json_dict(
{"data": data, "message_id": message_id},
)
raise BotError(f"Expected JSON dict in {context!r} event")
return cast(Dict[str, object], parsed)
return cast(dict[str, object], parsed)


def _default_error_handler(e: Exception, msg: str) -> None:
Expand All @@ -319,8 +320,8 @@ async def stream_request(
bot_name: str,
api_key: str = "",
*,
tools: Optional[List[ToolDefinition]] = None,
tool_executables: Optional[List[Callable]] = None,
tools: Optional[list[ToolDefinition]] = None,
tool_executables: Optional[list[Callable]] = None,
access_key: str = "",
access_key_deprecation_warning_stacklevel: int = 2,
session: Optional[httpx.AsyncClient] = None,
Expand All @@ -342,9 +343,9 @@ async def stream_request(
- `api_key` (`str = ""`): Your Poe API key, available at poe.com/api_key. You will need
this in case you are trying to use this function from a script/shell. Note that if an `api_key`
is provided, compute points will be charged on the account corresponding to the `api_key`.
- tools: (`Optional[List[ToolDefinition]] = None`): An list of ToolDefinition objects describing
- tools: (`Optional[list[ToolDefinition]] = None`): An list of ToolDefinition objects describing
the functions you have. This is used for OpenAI function calling.
- tool_executables: (`Optional[List[Callable]] = None`): An list of functions corresponding
- tool_executables: (`Optional[list[Callable]] = None`): An list of functions corresponding
to the ToolDefinitions. This is used for OpenAI function calling.
"""
Expand Down Expand Up @@ -387,8 +388,8 @@ async def stream_request(


async def _get_tool_results(
tool_executables: List[Callable], tool_calls: List[ToolCallDefinition]
) -> List[ToolResultDefinition]:
tool_executables: list[Callable], tool_calls: list[ToolCallDefinition]
) -> list[ToolResultDefinition]:
tool_executables_dict = {
executable.__name__: executable for executable in tool_executables
}
Expand Down Expand Up @@ -418,16 +419,16 @@ async def _get_tool_calls(
bot_name: str,
api_key: str = "",
*,
tools: List[ToolDefinition],
tools: list[ToolDefinition],
access_key: str = "",
access_key_deprecation_warning_stacklevel: int = 2,
session: Optional[httpx.AsyncClient] = None,
on_error: ErrorHandler = _default_error_handler,
num_tries: int = 2,
retry_sleep_time: float = 0.5,
base_url: str = "https://api.poe.com/bot/",
) -> List[ToolCallDefinition]:
tool_call_object_dict: Dict[int, Dict[str, Any]] = {}
) -> list[ToolCallDefinition]:
tool_call_object_dict: dict[int, dict[str, Any]] = {}
async for message in stream_request_base(
request=request,
bot_name=bot_name,
Expand All @@ -441,7 +442,11 @@ async def _get_tool_calls(
retry_sleep_time=retry_sleep_time,
base_url=base_url,
):
if message.data is not None:
if (
message.data is not None
and "choices" in message.data
and message.data["choices"]
):
finish_reason = message.data["choices"][0]["finish_reason"]
if finish_reason is None:
try:
Expand Down Expand Up @@ -473,9 +478,9 @@ async def stream_request_base(
bot_name: str,
api_key: str = "",
*,
tools: Optional[List[ToolDefinition]] = None,
tool_calls: Optional[List[ToolCallDefinition]] = None,
tool_results: Optional[List[ToolResultDefinition]] = None,
tools: Optional[list[ToolDefinition]] = None,
tool_calls: Optional[list[ToolCallDefinition]] = None,
tool_results: Optional[list[ToolResultDefinition]] = None,
access_key: str = "",
access_key_deprecation_warning_stacklevel: int = 2,
session: Optional[httpx.AsyncClient] = None,
Expand Down Expand Up @@ -531,24 +536,24 @@ async def stream_request_base(


def get_bot_response(
messages: List[ProtocolMessage],
messages: list[ProtocolMessage],
bot_name: str,
api_key: str,
*,
tools: Optional[List[ToolDefinition]] = None,
tool_executables: Optional[List[Callable]] = None,
tools: Optional[list[ToolDefinition]] = None,
tool_executables: Optional[list[Callable]] = None,
temperature: Optional[float] = None,
skip_system_prompt: Optional[bool] = None,
logit_bias: Optional[Dict[str, float]] = None,
stop_sequences: Optional[List[str]] = None,
logit_bias: Optional[dict[str, float]] = None,
stop_sequences: Optional[list[str]] = None,
base_url: str = "https://api.poe.com/bot/",
session: Optional[httpx.AsyncClient] = None,
) -> AsyncGenerator[BotMessage, None]:
"""
Use this function to invoke another Poe bot from your shell.
#### Parameters:
- `messages` (`List[ProtocolMessage]`): A list of messages representing your conversation.
- `messages` (`list[ProtocolMessage]`): A list of messages representing your conversation.
- `bot_name` (`str`): The bot that you want to invoke.
- `api_key` (`str`): Your Poe API key. This is available at: [poe.com/api_key](https://poe.com/api_key)
Expand Down Expand Up @@ -610,7 +615,7 @@ async def get_final_response(
provided, compute points will be charged on the account corresponding to the `api_key`.
"""
chunks: List[str] = []
chunks: list[str] = []
async for message in stream_request(
request,
bot_name,
Expand Down Expand Up @@ -639,7 +644,7 @@ def sync_bot_settings(
bot_name: str,
access_key: str = "",
*,
settings: Optional[Dict[str, Any]] = None,
settings: Optional[dict[str, Any]] = None,
base_url: str = "https://api.poe.com/bot/",
) -> None:
"""Fetch settings from the running bot server, and then sync them with Poe."""
Expand Down
Loading

0 comments on commit f008a93

Please sign in to comment.