diff --git a/tests/lib/streaming/test_messages.py b/tests/lib/streaming/test_messages.py index e7b649bd..c09baade 100644 --- a/tests/lib/streaming/test_messages.py +++ b/tests/lib/streaming/test_messages.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import inspect from typing import Any, TypeVar, cast from typing_extensions import Iterator, AsyncIterator, override @@ -16,7 +17,7 @@ base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") api_key = "my-anthropic-api-key" -client = Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True) +sync_client = Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True) async_client = AsyncAnthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True) _T = TypeVar("_T") @@ -113,7 +114,7 @@ class TestSyncMessages: def test_basic_response(self, respx_mock: MockRouter) -> None: respx_mock.post("/v1/messages").mock(return_value=httpx.Response(200, content=basic_response())) - with client.messages.stream( + with sync_client.messages.stream( max_tokens=1024, messages=[ { @@ -133,7 +134,7 @@ def test_basic_response(self, respx_mock: MockRouter) -> None: def test_context_manager(self, respx_mock: MockRouter) -> None: respx_mock.post("/v1/messages").mock(return_value=httpx.Response(200, content=basic_response())) - with client.messages.stream( + with sync_client.messages.stream( max_tokens=1024, messages=[ { @@ -190,3 +191,35 @@ async def test_context_manager(self, respx_mock: MockRouter) -> None: # response should be closed even if the body isn't read assert stream.response.is_closed + + +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +def test_stream_method_definition_in_sync(sync: bool) -> None: + client: Anthropic | AsyncAnthropic = sync_client if sync else async_client + + sig = inspect.signature(client.messages.stream) + generated_sig = inspect.signature(client.messages.create) + + errors: list[str] = [] + + for name, generated_param in generated_sig.parameters.items(): + if name == "stream": + # intentionally excluded + continue + + custom_param = sig.parameters.get(name) + if not custom_param: + errors.append(f"the `{name}` param is missing") + continue + + if custom_param.annotation != generated_param.annotation: + errors.append( + f"types for the `{name}` param are do not match; generated={repr(generated_param.annotation)} custom={repr(generated_param.annotation)}" + ) + continue + + if errors: + raise AssertionError( + f"{len(errors)} errors encountered with the {'sync' if sync else 'async'} client `messages.stream()` method:\n\n" + + "\n\n".join(errors) + )