Skip to content

Commit

Permalink
chore(internal): loosen type var restrictions (#301)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot authored Jan 5, 2024
1 parent 8671297 commit 5e5e1e7
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 44 deletions.
41 changes: 19 additions & 22 deletions src/anthropic/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
Body,
Omit,
Query,
ModelT,
Headers,
Timeout,
NotGiven,
Expand All @@ -61,7 +60,6 @@
HttpxSendArgs,
AsyncTransport,
RequestOptions,
UnknownResponse,
ModelBuilderProtocol,
BinaryResponseContent,
)
Expand Down Expand Up @@ -142,7 +140,7 @@ def __init__(
self.params = params


class BasePage(GenericModel, Generic[ModelT]):
class BasePage(GenericModel, Generic[_T]):
"""
Defines the core interface for pagination.
Expand All @@ -155,7 +153,7 @@ class BasePage(GenericModel, Generic[ModelT]):
"""

_options: FinalRequestOptions = PrivateAttr()
_model: Type[ModelT] = PrivateAttr()
_model: Type[_T] = PrivateAttr()

def has_next_page(self) -> bool:
items = self._get_page_items()
Expand All @@ -166,7 +164,7 @@ def has_next_page(self) -> bool:
def next_page_info(self) -> Optional[PageInfo]:
...

def _get_page_items(self) -> Iterable[ModelT]: # type: ignore[empty-body]
def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body]
...

def _params_from_url(self, url: URL) -> httpx.QueryParams:
Expand All @@ -191,13 +189,13 @@ def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
raise ValueError("Unexpected PageInfo state")


class BaseSyncPage(BasePage[ModelT], Generic[ModelT]):
class BaseSyncPage(BasePage[_T], Generic[_T]):
_client: SyncAPIClient = pydantic.PrivateAttr()

def _set_private_attributes(
self,
client: SyncAPIClient,
model: Type[ModelT],
model: Type[_T],
options: FinalRequestOptions,
) -> None:
self._model = model
Expand All @@ -212,7 +210,7 @@ def _set_private_attributes(
# methods should continue to work as expected as there is an alternative method
# to cast a model to a dictionary, model.dict(), which is used internally
# by pydantic.
def __iter__(self) -> Iterator[ModelT]: # type: ignore
def __iter__(self) -> Iterator[_T]: # type: ignore
for page in self.iter_pages():
for item in page._get_page_items():
yield item
Expand All @@ -237,13 +235,13 @@ def get_next_page(self: SyncPageT) -> SyncPageT:
return self._client._request_api_list(self._model, page=self.__class__, options=options)


class AsyncPaginator(Generic[ModelT, AsyncPageT]):
class AsyncPaginator(Generic[_T, AsyncPageT]):
def __init__(
self,
client: AsyncAPIClient,
options: FinalRequestOptions,
page_cls: Type[AsyncPageT],
model: Type[ModelT],
model: Type[_T],
) -> None:
self._model = model
self._client = client
Expand All @@ -266,7 +264,7 @@ def _parser(resp: AsyncPageT) -> AsyncPageT:

return await self._client.request(self._page_cls, self._options)

async def __aiter__(self) -> AsyncIterator[ModelT]:
async def __aiter__(self) -> AsyncIterator[_T]:
# https://github.com/microsoft/pyright/issues/3464
page = cast(
AsyncPageT,
Expand All @@ -276,20 +274,20 @@ async def __aiter__(self) -> AsyncIterator[ModelT]:
yield item


class BaseAsyncPage(BasePage[ModelT], Generic[ModelT]):
class BaseAsyncPage(BasePage[_T], Generic[_T]):
_client: AsyncAPIClient = pydantic.PrivateAttr()

def _set_private_attributes(
self,
model: Type[ModelT],
model: Type[_T],
client: AsyncAPIClient,
options: FinalRequestOptions,
) -> None:
self._model = model
self._client = client
self._options = options

async def __aiter__(self) -> AsyncIterator[ModelT]:
async def __aiter__(self) -> AsyncIterator[_T]:
async for page in self.iter_pages():
for item in page._get_page_items():
yield item
Expand Down Expand Up @@ -528,7 +526,7 @@ def _process_response_data(
if data is None:
return cast(ResponseT, None)

if cast_to is UnknownResponse:
if cast_to is object:
return cast(ResponseT, data)

try:
Expand Down Expand Up @@ -970,7 +968,7 @@ def _retry_request(

def _request_api_list(
self,
model: Type[ModelT],
model: Type[object],
page: Type[SyncPageT],
options: FinalRequestOptions,
) -> SyncPageT:
Expand Down Expand Up @@ -1132,7 +1130,7 @@ def get_api_list(
self,
path: str,
*,
model: Type[ModelT],
model: Type[object],
page: Type[SyncPageT],
body: Body | None = None,
options: RequestOptions = {},
Expand Down Expand Up @@ -1434,10 +1432,10 @@ async def _retry_request(

def _request_api_list(
self,
model: Type[ModelT],
model: Type[_T],
page: Type[AsyncPageT],
options: FinalRequestOptions,
) -> AsyncPaginator[ModelT, AsyncPageT]:
) -> AsyncPaginator[_T, AsyncPageT]:
return AsyncPaginator(client=self, options=options, page_cls=page, model=model)

@overload
Expand Down Expand Up @@ -1584,13 +1582,12 @@ def get_api_list(
self,
path: str,
*,
# TODO: support paginating `str`
model: Type[ModelT],
model: Type[_T],
page: Type[AsyncPageT],
body: Body | None = None,
options: RequestOptions = {},
method: str = "get",
) -> AsyncPaginator[ModelT, AsyncPageT]:
) -> AsyncPaginator[_T, AsyncPageT]:
opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options)
return self._request_api_list(model, page, opts)

Expand Down
4 changes: 2 additions & 2 deletions src/anthropic/_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import httpx

from ._types import NoneType, UnknownResponse, BinaryResponseContent
from ._types import NoneType, BinaryResponseContent
from ._utils import is_given, extract_type_var_from_base
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER
Expand Down Expand Up @@ -162,7 +162,7 @@ def _parse(self) -> R:
# `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
# this function would become unsafe but a type checker would not report an error.
if (
cast_to is not UnknownResponse
cast_to is not object
and not origin is list
and not origin is dict
and not origin is Union
Expand Down
17 changes: 11 additions & 6 deletions src/anthropic/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,6 @@ class RequestOptions(TypedDict, total=False):
idempotency_key: str


# Sentinel class used when the response type is an object with an unknown schema
class UnknownResponse:
...


# Sentinel class used until PEP 0661 is accepted
class NotGiven:
"""
Expand Down Expand Up @@ -339,7 +334,17 @@ def get(self, __key: str) -> str | None:

ResponseT = TypeVar(
"ResponseT",
bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
bound=Union[
object,
str,
None,
"BaseModel",
List[Any],
Dict[str, Any],
Response,
ModelBuilderProtocol,
BinaryResponseContent,
],
)

StrBytesIntFloat = Union[str, bytes, int, float]
Expand Down
8 changes: 1 addition & 7 deletions src/anthropic/resources/beta/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,7 @@

import httpx

from ..._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
)
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ..._utils import required_args, maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
Expand Down
8 changes: 1 addition & 7 deletions src/anthropic/resources/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,7 @@
import httpx

from ..types import Completion, completion_create_params
from .._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
)
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import required_args, maybe_transform
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
Expand Down

0 comments on commit 5e5e1e7

Please sign in to comment.