diff --git a/stdlib/re.pyi b/stdlib/re.pyi index 4962ab8edad9..12771441b779 100644 --- a/stdlib/re.pyi +++ b/stdlib/re.pyi @@ -67,7 +67,9 @@ class Match(Generic[AnyStr]): @overload def expand(self: Match[str], template: str) -> str: ... @overload - def expand(self: Match[bytes], template: ReadableBuffer) -> bytes: ... + def expand(self: Match[bytes], template: ReadableBuffer) -> bytes: ... # type: ignore[misc] + @overload + def expand(self, template: AnyStr) -> AnyStr: ... # group() returns "AnyStr" or "AnyStr | None", depending on the pattern. @overload def group(self, __group: Literal[0] = ...) -> AnyStr: ... @@ -115,46 +117,62 @@ class Pattern(Generic[AnyStr]): @overload def search(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... @overload - def search(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... + def search(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... # type: ignore[misc] + @overload + def search(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... @overload def match(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... @overload - def match(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... + def match(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... # type: ignore[misc] + @overload + def match(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... @overload def fullmatch(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Match[str] | None: ... @overload - def fullmatch(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... + def fullmatch(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Match[bytes] | None: ... # type: ignore[misc] + @overload + def fullmatch(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Match[AnyStr] | None: ... @overload def split(self: Pattern[str], string: str, maxsplit: int = ...) -> list[str | Any]: ... @overload def split(self: Pattern[bytes], string: ReadableBuffer, maxsplit: int = ...) -> list[bytes | Any]: ... + @overload + def split(self, string: AnyStr, maxsplit: int = ...) -> list[AnyStr | Any]: ... # return type depends on the number of groups in the pattern @overload def findall(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> list[Any]: ... @overload def findall(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> list[Any]: ... @overload + def findall(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> list[AnyStr]: ... + @overload def finditer(self: Pattern[str], string: str, pos: int = ..., endpos: int = ...) -> Iterator[Match[str]]: ... @overload - def finditer(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Iterator[Match[bytes]]: ... + def finditer(self: Pattern[bytes], string: ReadableBuffer, pos: int = ..., endpos: int = ...) -> Iterator[Match[bytes]]: ... # type: ignore[misc] + @overload + def finditer(self, string: AnyStr, pos: int = ..., endpos: int = ...) -> Iterator[Match[AnyStr]]: ... @overload def sub(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> str: ... @overload - def sub( + def sub( # type: ignore[misc] self: Pattern[bytes], repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer], string: ReadableBuffer, count: int = ..., ) -> bytes: ... @overload + def sub(self, repl: AnyStr | Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> AnyStr: ... + @overload def subn(self: Pattern[str], repl: str | Callable[[Match[str]], str], string: str, count: int = ...) -> tuple[str, int]: ... @overload - def subn( + def subn( # type: ignore[misc] self: Pattern[bytes], repl: ReadableBuffer | Callable[[Match[bytes]], ReadableBuffer], string: ReadableBuffer, count: int = ..., ) -> tuple[bytes, int]: ... + @overload + def subn(self, repl: AnyStr | Callable[[Match[AnyStr]], AnyStr], string: AnyStr, count: int = ...) -> tuple[AnyStr, int]: ... def __copy__(self) -> Pattern[AnyStr]: ... def __deepcopy__(self, __memo: Any) -> Pattern[AnyStr]: ... if sys.version_info >= (3, 9): diff --git a/test_cases/stdlib/check_re.py b/test_cases/stdlib/check_re.py new file mode 100644 index 000000000000..b6ab2b0d59d2 --- /dev/null +++ b/test_cases/stdlib/check_re.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import mmap +import re +import typing as t +from typing_extensions import assert_type + + +def check_search(str_pat: re.Pattern[str], bytes_pat: re.Pattern[bytes]) -> None: + assert_type(str_pat.search("x"), t.Optional[t.Match[str]]) + assert_type(bytes_pat.search(b"x"), t.Optional[t.Match[bytes]]) + assert_type(bytes_pat.search(bytearray(b"x")), t.Optional[t.Match[bytes]]) + assert_type(bytes_pat.search(mmap.mmap(0, 10)), t.Optional[t.Match[bytes]]) + + +def check_search_with_AnyStr(pattern: re.Pattern[t.AnyStr], string: t.AnyStr) -> re.Match[t.AnyStr]: + """See issue #9591""" + match = pattern.search(string) + if match is None: + raise ValueError(f"'{string!r}' does not match {pattern!r}") + return match + + +def check_no_ReadableBuffer_false_negatives() -> None: + re.compile("foo").search(bytearray(b"foo")) # type: ignore + re.compile("foo").search(mmap.mmap(0, 10)) # type: ignore diff --git a/test_cases/stdlib/typing/check_pattern.py b/test_cases/stdlib/typing/check_pattern.py deleted file mode 100644 index ec5c1c4f6141..000000000000 --- a/test_cases/stdlib/typing/check_pattern.py +++ /dev/null @@ -1,10 +0,0 @@ -from __future__ import annotations - -from typing import Match, Optional, Pattern -from typing_extensions import assert_type - - -def test_search(str_pat: Pattern[str], bytes_pat: Pattern[bytes]) -> None: - assert_type(str_pat.search("x"), Optional[Match[str]]) - assert_type(bytes_pat.search(b"x"), Optional[Match[bytes]]) - assert_type(bytes_pat.search(bytearray(b"x")), Optional[Match[bytes]])