Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(internal): support more input types #358

Merged
merged 1 commit into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/anthropic/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@
FileContent,
RequestFiles,
HttpxFileTypes,
Base64FileInput,
HttpxFileContent,
HttpxRequestFiles,
)
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t


def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
    """Return whether ``obj`` can be treated as a base64 file input.

    The check is purely type based: it is true for ``io.IOBase`` instances
    (for example ``io.BytesIO`` or ``io.StringIO``) and for ``os.PathLike``
    objects (for example ``pathlib.Path``), and false for everything else,
    including plain ``str`` and ``bytes`` values.
    """
    # A single isinstance call with a tuple replaces the chained `or` checks.
    return isinstance(obj, (io.IOBase, os.PathLike))


def is_file_content(obj: object) -> TypeGuard[FileContent]:
return (
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
Expand Down
2 changes: 2 additions & 0 deletions src/anthropic/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@
ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
ProxiesTypes = Union[str, Proxy, ProxiesDict]
# PathLike is not subscriptable at runtime in Python 3.8, so the parametrised
# ``PathLike[str]`` spelling is only exposed to type checkers; at runtime both
# aliases fall back to the bare ``PathLike`` class.
if TYPE_CHECKING:
    Base64FileInput = Union[IO[bytes], PathLike[str]]
    FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
    Base64FileInput = Union[IO[bytes], PathLike]
    FileContent = Union[IO[bytes], bytes, PathLike]
FileTypes = Union[
# file (or bytes)
Expand Down
39 changes: 38 additions & 1 deletion src/anthropic/_utils/_transform.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from __future__ import annotations

import io
import base64
import pathlib
from typing import Any, Mapping, TypeVar, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, override, get_type_hints

import anyio
import pydantic

from ._utils import (
is_list,
is_mapping,
is_iterable,
)
from .._files import is_base64_file_input
from ._typing import (
is_list_type,
is_union_type,
Expand All @@ -29,7 +34,7 @@
# TODO: ensure works correctly with forward references in all cases


PropertyFormat = Literal["iso8601", "custom"]
PropertyFormat = Literal["iso8601", "base64", "custom"]


class PropertyInfo:
Expand Down Expand Up @@ -201,6 +206,22 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)

if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None

if isinstance(data, pathlib.Path):
binary = data.read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()

if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()

if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")

return base64.b64encode(binary).decode("ascii")

return data


Expand Down Expand Up @@ -323,6 +344,22 @@ async def _async_format_data(data: object, format_: PropertyFormat, format_templ
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)

if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None

if isinstance(data, pathlib.Path):
binary = await anyio.Path(data).read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()

if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()

if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")

return base64.b64encode(binary).decode("ascii")

return data


Expand Down
1 change: 1 addition & 0 deletions tests/sample_file.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Hello, world!
29 changes: 29 additions & 0 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import io
import pathlib
from typing import Any, List, Union, TypeVar, Iterable, Optional, cast
from datetime import date, datetime
from typing_extensions import Required, Annotated, TypedDict

import pytest

from anthropic._types import Base64FileInput
from anthropic._utils import (
PropertyInfo,
transform as _transform,
Expand All @@ -17,6 +20,8 @@

_T = TypeVar("_T")

SAMPLE_FILE_PATH = pathlib.Path(__file__).parent.joinpath("sample_file.txt")


async def transform(
data: _T,
Expand Down Expand Up @@ -377,3 +382,27 @@ async def test_iterable_union_str(use_async: bool) -> None:
assert cast(Any, await transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]], use_async)) == [
{"fooBaz": "bar"}
]


class TypedDictBase64Input(TypedDict):
    """Params shape with a single ``foo`` field annotated with the ``base64`` format."""

    # Accepts either a plain string or a base64 file input.
    foo: Annotated[Union[str, Base64FileInput], PropertyInfo(format="base64")]


@parametrize
@pytest.mark.asyncio
async def test_base64_file_input(use_async: bool) -> None:
    """Check how the ``base64`` format treats strings, paths and IO objects."""
    # strings are left as-is
    from_str = await transform({"foo": "bar"}, TypedDictBase64Input, use_async)
    assert from_str == {"foo": "bar"}

    # pathlib.Path is automatically converted to base64; this expected value
    # decodes to "Hello, world!" followed by a newline
    from_path = await transform({"foo": SAMPLE_FILE_PATH}, TypedDictBase64Input, use_async)
    assert from_path == {"foo": "SGVsbG8sIHdvcmxkIQo="}  # type: ignore[comparison-overlap]

    # io instances are automatically converted to base64; text and binary
    # streams holding the same content must produce the same output
    expected_from_io = {"foo": "SGVsbG8sIHdvcmxkIQ=="}

    from_text_io = await transform({"foo": io.StringIO("Hello, world!")}, TypedDictBase64Input, use_async)
    assert from_text_io == expected_from_io  # type: ignore[comparison-overlap]

    from_bytes_io = await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async)
    assert from_bytes_io == expected_from_io  # type: ignore[comparison-overlap]