From 149382c6380e62a53c21ca6bac9c2e65c350f80f Mon Sep 17 00:00:00 2001
From: AlexWaygood <alex.waygood@gmail.com>
Date: Fri, 10 Mar 2023 13:39:24 +0000
Subject: [PATCH] [alt] typing: accept buffers in `IO.write`

Co-Authored-by: JelleZijstra <jelle.ziljstra@gmail.com>
---
 stdlib/codecs.pyi                    |  3 ++-
 stdlib/http/client.pyi               |  2 +-
 stdlib/io.pyi                        | 12 ++++++------
 stdlib/lzma.pyi                      |  2 +-
 stdlib/tempfile.pyi                  | 24 ++++++++++++++++++++++--
 stdlib/typing.pyi                    | 16 +++++++++++++++-
 test_cases/stdlib/typing/check_io.py | 21 +++++++++++++++++++++
 7 files changed, 68 insertions(+), 12 deletions(-)
 create mode 100644 test_cases/stdlib/typing/check_io.py

diff --git a/stdlib/codecs.pyi b/stdlib/codecs.pyi
index 5a22853b6aee..3f6d2d3d16b7 100644
--- a/stdlib/codecs.pyi
+++ b/stdlib/codecs.pyi
@@ -272,8 +272,9 @@ class StreamRecoder(BinaryIO):
     def readlines(self, sizehint: int | None = None) -> list[bytes]: ...
     def __next__(self) -> bytes: ...
     def __iter__(self) -> Self: ...
+    # Base class accepts more types than just bytes
     def write(self, data: bytes) -> None: ...  # type: ignore[override]
-    def writelines(self, list: Iterable[bytes]) -> None: ...
+    def writelines(self, list: Iterable[bytes]) -> None: ...  # type: ignore[override]
     def reset(self) -> None: ...
     def __getattr__(self, name: str) -> Any: ...
     def __enter__(self) -> Self: ...
diff --git a/stdlib/http/client.pyi b/stdlib/http/client.pyi
index 1f16bdc2dbab..cc142fbb23fd 100644
--- a/stdlib/http/client.pyi
+++ b/stdlib/http/client.pyi
@@ -101,7 +101,7 @@ class HTTPMessage(email.message.Message):
 
 def parse_headers(fp: io.BufferedIOBase, _class: Callable[[], email.message.Message] = ...) -> HTTPMessage: ...
 
-class HTTPResponse(io.BufferedIOBase, BinaryIO):
+class HTTPResponse(io.BufferedIOBase, BinaryIO):  # type: ignore[misc]  # incompatible method definitions in the base classes
     msg: HTTPMessage
     headers: HTTPMessage
     version: int
diff --git a/stdlib/io.pyi b/stdlib/io.pyi
index c3e07bacbe5a..c114f839594f 100644
--- a/stdlib/io.pyi
+++ b/stdlib/io.pyi
@@ -90,7 +90,7 @@ class BufferedIOBase(IOBase):
     def read(self, __size: int | None = ...) -> bytes: ...
     def read1(self, __size: int = ...) -> bytes: ...
 
-class FileIO(RawIOBase, BinaryIO):
+class FileIO(RawIOBase, BinaryIO):  # type: ignore[misc]  # incompatible definitions of writelines in the base classes
     mode: str
     name: FileDescriptorOrPath  # type: ignore[assignment]
     def __init__(
@@ -102,7 +102,7 @@ class FileIO(RawIOBase, BinaryIO):
     def read(self, __size: int = -1) -> bytes: ...
     def __enter__(self) -> Self: ...
 
-class BytesIO(BufferedIOBase, BinaryIO):
+class BytesIO(BufferedIOBase, BinaryIO):  # type: ignore[misc]  # incompatible definitions of methods in the base classes
     def __init__(self, initial_bytes: ReadableBuffer = ...) -> None: ...
     # BytesIO does not contain a "name" field. This workaround is necessary
     # to allow BytesIO sub-classes to add this field, as it is defined
@@ -113,17 +113,17 @@ class BytesIO(BufferedIOBase, BinaryIO):
     def getbuffer(self) -> memoryview: ...
     def read1(self, __size: int | None = -1) -> bytes: ...
 
-class BufferedReader(BufferedIOBase, BinaryIO):
+class BufferedReader(BufferedIOBase, BinaryIO):  # type: ignore[misc]  # incompatible definitions of methods in the base classes
     def __enter__(self) -> Self: ...
     def __init__(self, raw: RawIOBase, buffer_size: int = ...) -> None: ...
     def peek(self, __size: int = 0) -> bytes: ...
 
-class BufferedWriter(BufferedIOBase, BinaryIO):
+class BufferedWriter(BufferedIOBase, BinaryIO):  # type: ignore[misc]  # incompatible definitions of writelines in the base classes
     def __enter__(self) -> Self: ...
     def __init__(self, raw: RawIOBase, buffer_size: int = ...) -> None: ...
     def write(self, __buffer: ReadableBuffer) -> int: ...
 
-class BufferedRandom(BufferedReader, BufferedWriter):
+class BufferedRandom(BufferedReader, BufferedWriter):  # type: ignore[misc]  # incompatible definitions of methods in the base classes
     def __enter__(self) -> Self: ...
     def seek(self, __target: int, __whence: int = 0) -> int: ...  # stubtest needs this
 
@@ -144,7 +144,7 @@ class TextIOBase(IOBase):
     def readlines(self, __hint: int = -1) -> list[str]: ...  # type: ignore[override]
     def read(self, __size: int | None = ...) -> str: ...
 
-class TextIOWrapper(TextIOBase, TextIO):
+class TextIOWrapper(TextIOBase, TextIO):  # type: ignore[misc]  # incompatible definitions of write in the base classes
     def __init__(
         self,
         buffer: IO[bytes],
diff --git a/stdlib/lzma.pyi b/stdlib/lzma.pyi
index 34bd6f3f8db1..8e296bb5b357 100644
--- a/stdlib/lzma.pyi
+++ b/stdlib/lzma.pyi
@@ -104,7 +104,7 @@ class LZMACompressor:
 
 class LZMAError(Exception): ...
 
-class LZMAFile(io.BufferedIOBase, IO[bytes]):
+class LZMAFile(io.BufferedIOBase, IO[bytes]):  # type: ignore[misc]  # incompatible definitions of writelines in the base classes
     def __init__(
         self,
         filename: _PathOrFile | None = None,
diff --git a/stdlib/tempfile.pyi b/stdlib/tempfile.pyi
index dbff6d632d02..4be0d62a7870 100644
--- a/stdlib/tempfile.pyi
+++ b/stdlib/tempfile.pyi
@@ -1,6 +1,6 @@
 import io
 import sys
-from _typeshed import BytesPath, GenericPath, StrPath, WriteableBuffer
+from _typeshed import BytesPath, GenericPath, ReadableBuffer, StrPath, WriteableBuffer
 from collections.abc import Iterable, Iterator
 from types import TracebackType
 from typing import IO, Any, AnyStr, Generic, overload
@@ -215,7 +215,17 @@ class _TemporaryFileWrapper(Generic[AnyStr], IO[AnyStr]):
     def tell(self) -> int: ...
     def truncate(self, size: int | None = ...) -> int: ...
     def writable(self) -> bool: ...
+    @overload
+    def write(self: _TemporaryFileWrapper[str], s: str) -> int: ...
+    @overload
+    def write(self: _TemporaryFileWrapper[bytes], s: ReadableBuffer) -> int: ...
+    @overload
     def write(self, s: AnyStr) -> int: ...
+    @overload
+    def writelines(self: _TemporaryFileWrapper[str], lines: Iterable[str]) -> None: ...
+    @overload
+    def writelines(self: _TemporaryFileWrapper[bytes], lines: Iterable[ReadableBuffer]) -> None: ...
+    @overload
     def writelines(self, lines: Iterable[AnyStr]) -> None: ...
 
 if sys.version_info >= (3, 11):
@@ -392,8 +402,18 @@ class SpooledTemporaryFile(IO[AnyStr], _SpooledTemporaryFileBase):
     def seek(self, offset: int, whence: int = ...) -> int: ...
     def tell(self) -> int: ...
     def truncate(self, size: int | None = None) -> None: ...  # type: ignore[override]
+    @overload
+    def write(self: SpooledTemporaryFile[str], s: str) -> int: ...
+    @overload
+    def write(self: SpooledTemporaryFile[bytes], s: ReadableBuffer) -> int: ...
+    @overload
     def write(self, s: AnyStr) -> int: ...
-    def writelines(self, iterable: Iterable[AnyStr]) -> None: ...  # type: ignore[override]
+    @overload
+    def writelines(self: SpooledTemporaryFile[str], lines: Iterable[str]) -> None: ...
+    @overload
+    def writelines(self: SpooledTemporaryFile[bytes], lines: Iterable[ReadableBuffer]) -> None: ...
+    @overload
+    def writelines(self, lines: Iterable[AnyStr]) -> None: ...
     def __iter__(self) -> Iterator[AnyStr]: ...  # type: ignore[override]
     # These exist at runtime only on 3.11+.
     def readable(self) -> bool: ...
diff --git a/stdlib/typing.pyi b/stdlib/typing.pyi
index efd61ad8bf43..0a8de1a7b538 100644
--- a/stdlib/typing.pyi
+++ b/stdlib/typing.pyi
@@ -2,7 +2,7 @@ import collections  # Needed by aliases like DefaultDict, see mypy issue 2986
 import sys
 import typing_extensions
 from _collections_abc import dict_items, dict_keys, dict_values
-from _typeshed import IdentityFunction, Incomplete, SupportsKeysAndGetItem
+from _typeshed import IdentityFunction, Incomplete, ReadableBuffer, SupportsKeysAndGetItem
 from abc import ABCMeta, abstractmethod
 from contextlib import AbstractAsyncContextManager, AbstractContextManager
 from re import Match as Match, Pattern as Pattern
@@ -687,8 +687,22 @@ class IO(Iterator[AnyStr], Generic[AnyStr]):
     @abstractmethod
     def writable(self) -> bool: ...
     @abstractmethod
+    @overload
+    def write(self: IO[str], __s: str) -> int: ...
+    @abstractmethod
+    @overload
+    def write(self: IO[bytes], __s: ReadableBuffer) -> int: ...
+    @abstractmethod
+    @overload
     def write(self, __s: AnyStr) -> int: ...
     @abstractmethod
+    @overload
+    def writelines(self: IO[str], __lines: Iterable[str]) -> None: ...
+    @abstractmethod
+    @overload
+    def writelines(self: IO[bytes], __lines: Iterable[ReadableBuffer]) -> None: ...
+    @abstractmethod
+    @overload
     def writelines(self, __lines: Iterable[AnyStr]) -> None: ...
     @abstractmethod
     def __next__(self) -> AnyStr: ...
diff --git a/test_cases/stdlib/typing/check_io.py b/test_cases/stdlib/typing/check_io.py
new file mode 100644
index 000000000000..67f16dc91765
--- /dev/null
+++ b/test_cases/stdlib/typing/check_io.py
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+import mmap
+from typing import IO, AnyStr
+
+
+def check_write(io_bytes: IO[bytes], io_str: IO[str], io_anystr: IO[AnyStr], any_str: AnyStr, buf: mmap.mmap) -> None:
+    io_bytes.write(b"")
+    io_bytes.write(buf)
+    io_bytes.write("")  # type: ignore
+    io_bytes.write(any_str)  # type: ignore
+
+    io_str.write(b"")  # type: ignore
+    io_str.write(buf)  # type: ignore
+    io_str.write("")
+    io_str.write(any_str)  # type: ignore
+
+    io_anystr.write(b"")  # type: ignore
+    io_anystr.write(buf)  # type: ignore
+    io_anystr.write("")  # type: ignore
+    io_anystr.write(any_str)