From a78bcedfd18271a66b38b1520d15a47009bebf7e Mon Sep 17 00:00:00 2001 From: John Li Date: Mon, 6 Jan 2025 12:09:15 -0800 Subject: [PATCH] fix types for python 3.8 typing.Dict/typing.List -> dict/list --- docs/generate_api_reference.py | 12 +++---- src/fastapi_poe/base.py | 34 ++++++++------------ src/fastapi_poe/client.py | 59 +++++++++++++++++----------------- src/fastapi_poe/types.py | 38 +++++++++++----------- 4 files changed, 69 insertions(+), 74 deletions(-) diff --git a/docs/generate_api_reference.py b/docs/generate_api_reference.py index 692a730..af0fd80 100644 --- a/docs/generate_api_reference.py +++ b/docs/generate_api_reference.py @@ -13,7 +13,7 @@ import sys import types from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Optional, Union sys.path.append("../src") import fastapi_poe @@ -32,7 +32,7 @@ class DocumentationData: name: str docstring: Optional[str] data_type: str - children: List = field(default_factory=lambda: []) + children: list = field(default_factory=lambda: []) def _unwrap_func(func_obj: Union[staticmethod, Callable]) -> Callable: @@ -43,8 +43,8 @@ def _unwrap_func(func_obj: Union[staticmethod, Callable]) -> Callable: def get_documentation_data( - *, module: types.ModuleType, documented_items: List[str] -) -> Dict[str, DocumentationData]: + *, module: types.ModuleType, documented_items: list[str] +) -> dict[str, DocumentationData]: data_dict = {} for name, obj in inspect.getmembers(module): if ( @@ -75,8 +75,8 @@ def get_documentation_data( def generate_documentation( *, - data_dict: Dict[str, DocumentationData], - documented_items: List[str], + data_dict: dict[str, DocumentationData], + documented_items: list[str], output_filename: str, ) -> None: # reset the file first diff --git a/src/fastapi_poe/base.py b/src/fastapi_poe/base.py index 5ff987e..9dd0264 100644 --- a/src/fastapi_poe/base.py +++ b/src/fastapi_poe/base.py @@ -7,18 +7,9 @@ import sys import warnings from collections import defaultdict +from collections.abc import AsyncIterable, Awaitable, Sequence from dataclasses import dataclass -from typing import ( - AsyncIterable, - Awaitable, - BinaryIO, - Callable, - Dict, - List, - Optional, - Sequence, - Union, -) +from typing import BinaryIO, Callable, Optional, Union import httpx import httpx_sse @@ -656,7 +647,7 @@ def make_prompt_author_role_alternated( async def capture_cost( self, request: QueryRequest, - amounts: Union[List[CostItem], CostItem], + amounts: Union[list[CostItem], CostItem], base_url: str = "https://api.poe.com/", ) -> None: """ @@ -666,7 +657,7 @@ async def capture_cost( #### Parameters: - `request` (`QueryRequest`): The currently handlded QueryRequest object. - - `amounts` (`Union[List[CostItem], CostItem]`): The to be captured amounts. + - `amounts` (`Union[list[CostItem], CostItem]`): The to be captured amounts. """ @@ -690,7 +681,7 @@ async def capture_cost( async def authorize_cost( self, request: QueryRequest, - amounts: Union[List[CostItem], CostItem], + amounts: Union[list[CostItem], CostItem], base_url: str = "https://api.poe.com/", ) -> None: """ @@ -700,7 +691,7 @@ async def authorize_cost( #### Parameters: - `request` (`QueryRequest`): The currently handlded QueryRequest object. - - `amounts` (`Union[List[CostItem], CostItem]`): The to be authorized amounts. + - `amounts` (`Union[list[CostItem], CostItem]`): The to be authorized amounts. """ @@ -722,15 +713,18 @@ async def authorize_cost( raise InsufficientFundError() async def _cost_requests_inner( - self, amounts: Union[List[CostItem], CostItem], access_key: str, url: str + self, amounts: Union[list[CostItem], CostItem], access_key: str, url: str ) -> bool: amounts = [amounts] if isinstance(amounts, CostItem) else amounts amounts_dicts = [amount.model_dump() for amount in amounts] data = {"amounts": amounts_dicts, "access_key": access_key} try: - async with httpx.AsyncClient(timeout=300) as client, httpx_sse.aconnect_sse( - client, method="POST", url=url, json=data - ) as event_source: + async with ( + httpx.AsyncClient(timeout=300) as client, + httpx_sse.aconnect_sse( + client, method="POST", url=url, json=data + ) as event_source, + ): if event_source.response.status_code != 200: error_pieces = [ json.loads(event.data).get("message", "") @@ -799,7 +793,7 @@ def error_event( allow_retry: bool = True, error_type: Optional[str] = None, ) -> ServerSentEvent: - data: Dict[str, Union[bool, str]] = {"allow_retry": allow_retry} + data: dict[str, Union[bool, str]] = {"allow_retry": allow_retry} if text is not None: data["text"] = text if raw_response is not None: diff --git a/src/fastapi_poe/client.py b/src/fastapi_poe/client.py index 3250538..62d7023 100644 --- a/src/fastapi_poe/client.py +++ b/src/fastapi_poe/client.py @@ -10,8 +10,9 @@ import inspect import json import warnings +from collections.abc import AsyncGenerator from dataclasses import dataclass, field -from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, cast +from typing import Any, Callable, Optional, cast import httpx import httpx_sse @@ -66,14 +67,14 @@ class _BotContext: on_error: Optional[ErrorHandler] = field(default=None, repr=False) @property - def headers(self) -> Dict[str, str]: + def headers(self) -> dict[str, str]: headers = {"Accept": "application/json"} if self.api_key is not None: headers["Authorization"] = f"Bearer {self.api_key}" return headers async def report_error( - self, message: str, metadata: Optional[Dict[str, Any]] = None + self, message: str, metadata: Optional[dict[str, Any]] = None ) -> None: """Report an error to the bot server.""" if self.on_error is not None: @@ -148,11 +149,11 @@ async def perform_query_request( self, *, request: QueryRequest, - tools: Optional[List[ToolDefinition]], - tool_calls: Optional[List[ToolCallDefinition]], - tool_results: Optional[List[ToolResultDefinition]], + tools: Optional[list[ToolDefinition]], + tool_calls: Optional[list[ToolCallDefinition]], + tool_results: Optional[list[ToolResultDefinition]], ) -> AsyncGenerator[BotMessage, None]: - chunks: List[str] = [] + chunks: list[str] = [] message_id = request.message_id event_count = 0 error_reported = False @@ -291,7 +292,7 @@ async def _get_single_json_field( async def _load_json_dict( self, data: str, context: str, message_id: Identifier - ) -> Dict[str, object]: + ) -> dict[str, object]: try: parsed = json.loads(data) except json.JSONDecodeError: @@ -307,7 +308,7 @@ async def _load_json_dict( {"data": data, "message_id": message_id}, ) raise BotError(f"Expected JSON dict in {context!r} event") - return cast(Dict[str, object], parsed) + return cast(dict[str, object], parsed) def _default_error_handler(e: Exception, msg: str) -> None: @@ -319,8 +320,8 @@ async def stream_request( bot_name: str, api_key: str = "", *, - tools: Optional[List[ToolDefinition]] = None, - tool_executables: Optional[List[Callable]] = None, + tools: Optional[list[ToolDefinition]] = None, + tool_executables: Optional[list[Callable]] = None, access_key: str = "", access_key_deprecation_warning_stacklevel: int = 2, session: Optional[httpx.AsyncClient] = None, @@ -342,9 +343,9 @@ async def stream_request( - `api_key` (`str = ""`): Your Poe API key, available at poe.com/api_key. You will need this in case you are trying to use this function from a script/shell. Note that if an `api_key` is provided, compute points will be charged on the account corresponding to the `api_key`. - - tools: (`Optional[List[ToolDefinition]] = None`): An list of ToolDefinition objects describing + - tools: (`Optional[list[ToolDefinition]] = None`): An list of ToolDefinition objects describing the functions you have. This is used for OpenAI function calling. - - tool_executables: (`Optional[List[Callable]] = None`): An list of functions corresponding + - tool_executables: (`Optional[list[Callable]] = None`): An list of functions corresponding to the ToolDefinitions. This is used for OpenAI function calling. """ @@ -387,8 +388,8 @@ async def stream_request( async def _get_tool_results( - tool_executables: List[Callable], tool_calls: List[ToolCallDefinition] -) -> List[ToolResultDefinition]: + tool_executables: list[Callable], tool_calls: list[ToolCallDefinition] +) -> list[ToolResultDefinition]: tool_executables_dict = { executable.__name__: executable for executable in tool_executables } @@ -418,7 +419,7 @@ async def _get_tool_calls( bot_name: str, api_key: str = "", *, - tools: List[ToolDefinition], + tools: list[ToolDefinition], access_key: str = "", access_key_deprecation_warning_stacklevel: int = 2, session: Optional[httpx.AsyncClient] = None, @@ -426,8 +427,8 @@ async def _get_tool_calls( num_tries: int = 2, retry_sleep_time: float = 0.5, base_url: str = "https://api.poe.com/bot/", -) -> List[ToolCallDefinition]: - tool_call_object_dict: Dict[int, Dict[str, Any]] = {} +) -> list[ToolCallDefinition]: + tool_call_object_dict: dict[int, dict[str, Any]] = {} async for message in stream_request_base( request=request, bot_name=bot_name, @@ -477,9 +478,9 @@ async def stream_request_base( bot_name: str, api_key: str = "", *, - tools: Optional[List[ToolDefinition]] = None, - tool_calls: Optional[List[ToolCallDefinition]] = None, - tool_results: Optional[List[ToolResultDefinition]] = None, + tools: Optional[list[ToolDefinition]] = None, + tool_calls: Optional[list[ToolCallDefinition]] = None, + tool_results: Optional[list[ToolResultDefinition]] = None, access_key: str = "", access_key_deprecation_warning_stacklevel: int = 2, session: Optional[httpx.AsyncClient] = None, @@ -535,16 +536,16 @@ async def stream_request_base( def get_bot_response( - messages: List[ProtocolMessage], + messages: list[ProtocolMessage], bot_name: str, api_key: str, *, - tools: Optional[List[ToolDefinition]] = None, - tool_executables: Optional[List[Callable]] = None, + tools: Optional[list[ToolDefinition]] = None, + tool_executables: Optional[list[Callable]] = None, temperature: Optional[float] = None, skip_system_prompt: Optional[bool] = None, - logit_bias: Optional[Dict[str, float]] = None, - stop_sequences: Optional[List[str]] = None, + logit_bias: Optional[dict[str, float]] = None, + stop_sequences: Optional[list[str]] = None, base_url: str = "https://api.poe.com/bot/", session: Optional[httpx.AsyncClient] = None, ) -> AsyncGenerator[BotMessage, None]: @@ -552,7 +553,7 @@ def get_bot_response( Use this function to invoke another Poe bot from your shell. #### Parameters: - - `messages` (`List[ProtocolMessage]`): A list of messages representing your conversation. + - `messages` (`list[ProtocolMessage]`): A list of messages representing your conversation. - `bot_name` (`str`): The bot that you want to invoke. - `api_key` (`str`): Your Poe API key. This is available at: [poe.com/api_key](https://poe.com/api_key) @@ -614,7 +615,7 @@ async def get_final_response( provided, compute points will be charged on the account corresponding to the `api_key`. """ - chunks: List[str] = [] + chunks: list[str] = [] async for message in stream_request( request, bot_name, @@ -643,7 +644,7 @@ def sync_bot_settings( bot_name: str, access_key: str = "", *, - settings: Optional[Dict[str, Any]] = None, + settings: Optional[dict[str, Any]] = None, base_url: str = "https://api.poe.com/bot/", ) -> None: """Fetch settings from the running bot server, and then sync them with Poe.""" diff --git a/src/fastapi_poe/types.py b/src/fastapi_poe/types.py index c4a8b14..0039f06 100644 --- a/src/fastapi_poe/types.py +++ b/src/fastapi_poe/types.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Optional from fastapi import Request from pydantic import BaseModel, ConfigDict, Field @@ -67,8 +67,8 @@ class ProtocolMessage(BaseModel): - `content_type` (`ContentType="text/markdown"`) - `timestamp` (`int = 0`) - `message_id` (`str = ""`) - - `feedback` (`List[MessageFeedback] = []`) - - `attachments` (`List[Attachment] = []`) + - `feedback` (`list[MessageFeedback] = []`) + - `attachments` (`list[Attachment] = []`) """ @@ -78,8 +78,8 @@ class ProtocolMessage(BaseModel): content_type: ContentType = "text/markdown" timestamp: int = 0 message_id: str = "" - feedback: List[MessageFeedback] = Field(default_factory=list) - attachments: List[Attachment] = Field(default_factory=list) + feedback: list[MessageFeedback] = Field(default_factory=list) + attachments: list[Attachment] = Field(default_factory=list) class RequestContext(BaseModel): @@ -103,7 +103,7 @@ class QueryRequest(BaseRequest): Request parameters for a query request. #### Fields: - - `query` (`List[ProtocolMessage]`): list of message representing the current state of the chat. + - `query` (`list[ProtocolMessage]`): list of message representing the current state of the chat. - `user_id` (`Identifier`): an anonymized identifier representing a user. This is persistent for subsequent requests from that user. - `conversation_id` (`Identifier`): an identifier representing a chat. This is @@ -113,14 +113,14 @@ class QueryRequest(BaseRequest): on Poe. - `temperature` (`float | None = None`): Temperature input to be used for model inference. - `skip_system_prompt` (`bool = False`): Whether to use any system prompting or not. - - `logit_bias` (`Dict[str, float] = {}`) - - `stop_sequences` (`List[str] = []`) + - `logit_bias` (`dict[str, float] = {}`) + - `stop_sequences` (`list[str] = []`) - `language_code` (`str = "en"`): BCP 47 language code of the user's client. - `bot_query_id` (`str = ""`): an identifier representing a bot query. """ - query: List[ProtocolMessage] + query: list[ProtocolMessage] user_id: Identifier conversation_id: Identifier message_id: Identifier @@ -129,8 +129,8 @@ class QueryRequest(BaseRequest): access_key: str = "" temperature: Optional[float] = None skip_system_prompt: bool = False - logit_bias: Dict[str, float] = {} - stop_sequences: List[str] = [] + logit_bias: dict[str, float] = {} + stop_sequences: list[str] = [] language_code: str = "en" bot_query_id: Identifier = "" @@ -186,12 +186,12 @@ class ReportErrorRequest(BaseRequest): Request parameters for a report_error request. #### Fields: - `message` (`str`) - - `metadata` (`Dict[str, Any]`) + - `metadata` (`dict[str, Any]`) """ message: str - metadata: Dict[str, Any] + metadata: dict[str, Any] class SettingsResponse(BaseModel): @@ -199,7 +199,7 @@ class SettingsResponse(BaseModel): An object representing your bot's response to a settings object. #### Fields: - - `server_bot_dependencies` (`Dict[str, int] = {}`): Information about other bots that your bot + - `server_bot_dependencies` (`dict[str, int] = {}`): Information about other bots that your bot uses. This is used to facilitate the Bot Query API. - `allow_attachments` (`bool = False`): Whether to allow users to upload attachments to your bot. @@ -222,7 +222,7 @@ class SettingsResponse(BaseModel): context_clear_window_secs: Optional[int] = None # deprecated allow_user_context_clear: Optional[bool] = None # deprecated - server_bot_dependencies: Dict[str, int] = Field(default_factory=dict) + server_bot_dependencies: dict[str, int] = Field(default_factory=dict) allow_attachments: Optional[bool] = None introduction_message: Optional[str] = None expand_text_attachments: Optional[bool] = None @@ -247,7 +247,7 @@ class PartialResponse(BaseModel): - `text` (`str`): The actual text you want to display to the user. Note that this should solely be the text in the next token since Poe will automatically concatenate all tokens before displaying the response to the user. - - `data` (`Optional[Dict[str, Any]]`): Used to send arbitrary json data to Poe. This is + - `data` (`Optional[dict[str, Any]]`): Used to send arbitrary json data to Poe. This is currently only used for OpenAI function calling. - `is_suggested_reply` (`bool = False`): Setting this to true will create a suggested reply with the provided text value. @@ -269,7 +269,7 @@ class PartialResponse(BaseModel): """ - data: Optional[Dict[str, Any]] = None + data: Optional[dict[str, Any]] = None """Used when a bot returns the json event.""" raw_response: object = None @@ -338,8 +338,8 @@ class ToolDefinition(BaseModel): class FunctionDefinition(BaseModel): class ParametersDefinition(BaseModel): type: str - properties: Dict[str, object] - required: Optional[List[str]] = None + properties: dict[str, object] + required: Optional[list[str]] = None name: str description: str