Skip to content

Commit

Permalink
gh-104050: Add more annotations to Tools/clinic.py (#104544)
Browse files Browse the repository at this point in the history
  • Loading branch information
sobolevn authored May 16, 2023
1 parent 1163782 commit a454a66
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 30 deletions.
82 changes: 55 additions & 27 deletions Tools/clinic/clinic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from collections.abc import Callable
from types import FunctionType, NoneType
from typing import Any, NamedTuple
from typing import Any, NamedTuple, NoReturn, Literal, overload

# TODO:
#
Expand Down Expand Up @@ -59,37 +59,37 @@
}

class Unspecified:
def __repr__(self):
def __repr__(self) -> str:
return '<Unspecified>'

unspecified = Unspecified()


class Null:
def __repr__(self):
def __repr__(self) -> str:
return '<Null>'

NULL = Null()


class Unknown:
def __repr__(self):
def __repr__(self) -> str:
return '<Unknown>'

unknown = Unknown()

sig_end_marker = '--'

Appender = Callable[[str], None]
Outputter = Callable[[None], str]
Outputter = Callable[[], str]

class _TextAccumulator(NamedTuple):
text: list[str]
append: Appender
output: Outputter

def _text_accumulator():
text = []
def _text_accumulator() -> _TextAccumulator:
text: list[str] = []
def output():
s = ''.join(text)
text.clear()
Expand All @@ -98,10 +98,10 @@ def output():


class TextAccumulator(NamedTuple):
text: list[str]
append: Appender
output: Outputter

def text_accumulator():
def text_accumulator() -> TextAccumulator:
"""
Creates a simple text accumulator / joiner.
Expand All @@ -115,8 +115,28 @@ def text_accumulator():
text, append, output = _text_accumulator()
return TextAccumulator(append, output)


def warn_or_fail(fail=False, *args, filename=None, line_number=None):
@overload
def warn_or_fail(
*args: object,
fail: Literal[True],
filename: str | None = None,
line_number: int | None = None,
) -> NoReturn: ...

@overload
def warn_or_fail(
*args: object,
fail: Literal[False] = False,
filename: str | None = None,
line_number: int | None = None,
) -> None: ...

def warn_or_fail(
*args: object,
fail: bool = False,
filename: str | None = None,
line_number: int | None = None,
) -> None:
joined = " ".join([str(a) for a in args])
add, output = text_accumulator()
if fail:
Expand All @@ -139,14 +159,22 @@ def warn_or_fail(fail=False, *args, filename=None, line_number=None):
sys.exit(-1)


def warn(*args, filename=None, line_number=None):
return warn_or_fail(False, *args, filename=filename, line_number=line_number)
def warn(
*args: object,
filename: str | None = None,
line_number: int | None = None,
) -> None:
return warn_or_fail(*args, filename=filename, line_number=line_number, fail=False)

def fail(*args, filename=None, line_number=None):
return warn_or_fail(True, *args, filename=filename, line_number=line_number)
def fail(
*args: object,
filename: str | None = None,
line_number: int | None = None,
) -> NoReturn:
warn_or_fail(*args, filename=filename, line_number=line_number, fail=True)


def quoted_for_c_string(s):
def quoted_for_c_string(s: str) -> str:
for old, new in (
('\\', '\\\\'), # must be first!
('"', '\\"'),
Expand All @@ -155,13 +183,13 @@ def quoted_for_c_string(s):
s = s.replace(old, new)
return s

def c_repr(s):
def c_repr(s: str) -> str:
return '"' + s + '"'


is_legal_c_identifier = re.compile('^[A-Za-z_][A-Za-z0-9_]*$').match

def is_legal_py_identifier(s):
def is_legal_py_identifier(s: str) -> bool:
return all(is_legal_c_identifier(field) for field in s.split('.'))

# identifiers that are okay in Python but aren't a good idea in C.
Expand All @@ -174,7 +202,7 @@ def is_legal_py_identifier(s):
typedef typeof union unsigned void volatile while
""".strip().split())

def ensure_legal_c_identifier(s):
def ensure_legal_c_identifier(s: str) -> str:
# for now, just complain if what we're given isn't legal
if not is_legal_c_identifier(s):
fail("Illegal C identifier: {}".format(s))
Expand All @@ -183,22 +211,22 @@ def ensure_legal_c_identifier(s):
return s + "_value"
return s

def rstrip_lines(s):
def rstrip_lines(s: str) -> str:
text, add, output = _text_accumulator()
for line in s.split('\n'):
add(line.rstrip())
add('\n')
text.pop()
return output()

def format_escape(s):
def format_escape(s: str) -> str:
# double up curly-braces, this string will be used
# as part of a format_map() template later
s = s.replace('{', '{{')
s = s.replace('}', '}}')
return s

def linear_format(s, **kwargs):
def linear_format(s: str, **kwargs: str) -> str:
"""
Perform str.format-like substitution, except:
* The strings substituted must be on lines by
Expand Down Expand Up @@ -242,7 +270,7 @@ def linear_format(s, **kwargs):

return output()[:-1]

def indent_all_lines(s, prefix):
def indent_all_lines(s: str, prefix: str) -> str:
"""
Returns 's', with 'prefix' prepended to all lines.
Expand All @@ -263,7 +291,7 @@ def indent_all_lines(s, prefix):
final.append(last)
return ''.join(final)

def suffix_all_lines(s, suffix):
def suffix_all_lines(s: str, suffix: str) -> str:
"""
Returns 's', with 'suffix' appended to all lines.
Expand All @@ -283,7 +311,7 @@ def suffix_all_lines(s, suffix):
return ''.join(final)


def version_splitter(s):
def version_splitter(s: str) -> tuple[int, ...]:
"""Splits a version string into a tuple of integers.
The following ASCII characters are allowed, and employ
Expand All @@ -294,7 +322,7 @@ def version_splitter(s):
(This permits Python-style version strings such as "1.4b3".)
"""
version = []
accumulator = []
accumulator: list[str] = []
def flush():
if not accumulator:
raise ValueError('Unsupported version string: ' + repr(s))
Expand All @@ -314,7 +342,7 @@ def flush():
flush()
return tuple(version)

def version_comparitor(version1, version2):
def version_comparitor(version1: str, version2: str) -> Literal[-1, 0, 1]:
iterator = itertools.zip_longest(version_splitter(version1), version_splitter(version2), fillvalue=0)
for i, (a, b) in enumerate(iterator):
if a < b:
Expand Down
7 changes: 4 additions & 3 deletions Tools/clinic/cpp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
import sys
from collections.abc import Callable
from typing import NoReturn


TokenAndCondition = tuple[str, str]
Expand Down Expand Up @@ -30,7 +31,7 @@ class Monitor:
is_a_simple_defined: Callable[[str], re.Match[str] | None]
is_a_simple_defined = re.compile(r'^defined\s*\(\s*[A-Za-z0-9_]+\s*\)$').match

def __init__(self, filename=None, *, verbose: bool = False):
def __init__(self, filename: str | None = None, *, verbose: bool = False) -> None:
self.stack: TokenStack = []
self.in_comment = False
self.continuation: str | None = None
Expand All @@ -55,7 +56,7 @@ def condition(self) -> str:
"""
return " && ".join(condition for token, condition in self.stack)

def fail(self, *a):
def fail(self, *a: object) -> NoReturn:
if self.filename:
filename = " " + self.filename
else:
Expand All @@ -64,7 +65,7 @@ def fail(self, *a):
print(" ", ' '.join(str(x) for x in a))
sys.exit(-1)

def close(self):
def close(self) -> None:
if self.stack:
self.fail("Ended file while still in a preprocessor conditional block!")

Expand Down
1 change: 1 addition & 0 deletions Tools/clinic/mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ strict_concatenate = True
warn_redundant_casts = True
warn_unused_ignores = True
warn_unused_configs = True
warn_unreachable = True
files = Tools/clinic/

0 comments on commit a454a66

Please sign in to comment.