Skip to content

Commit

Permalink
chore(internal): update base client (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot authored Sep 6, 2023
1 parent 11196b7 commit 8e0dca4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
20 changes: 17 additions & 3 deletions src/anthropic/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
overload,
)
from functools import lru_cache
from typing_extensions import Literal, get_origin
from typing_extensions import Literal, get_args, get_origin

import anyio
import httpx
Expand Down Expand Up @@ -458,6 +458,14 @@ def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, o
serialized[key] = value
return serialized

def _extract_stream_chunk_type(self, stream_cls: type) -> type:
args = get_args(stream_cls)
if not args:
raise TypeError(
f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
)
return cast(type, args[0])

def _process_response(
self,
*,
Expand Down Expand Up @@ -793,7 +801,10 @@ def _request(
raise APIConnectionError(request=request) from err

if stream:
stream_cls = stream_cls or cast("type[_StreamT] | None", self._default_stream_cls)
if stream_cls:
return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self)

stream_cls = cast("type[_StreamT] | None", self._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return stream_cls(cast_to=cast_to, response=response, client=self)
Expand Down Expand Up @@ -1156,7 +1167,10 @@ async def _request(
raise APIConnectionError(request=request) from err

if stream:
stream_cls = stream_cls or cast("type[_AsyncStreamT] | None", self._default_stream_cls)
if stream_cls:
return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self)

stream_cls = cast("type[_AsyncStreamT] | None", self._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return stream_cls(cast_to=cast_to, response=response, client=self)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
import pytest
from respx import MockRouter

from anthropic import Stream, Anthropic, AsyncStream, AsyncAnthropic
from anthropic import Anthropic, AsyncAnthropic
from anthropic._types import Omit
from anthropic._models import BaseModel, FinalRequestOptions
from anthropic._streaming import Stream, AsyncStream
from anthropic._base_client import BaseClient, make_request_options

base_url = os.environ.get("API_BASE_URL", "http://127.0.0.1:4010")
Expand Down

0 comments on commit 8e0dca4

Please sign in to comment.