Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow inferring +int to be a Literal #16910

Merged
merged 4 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4437,6 +4437,10 @@ def try_getting_int_literals(self, index: Expression) -> list[int] | None:
operand = index.expr
if isinstance(operand, IntExpr):
return [-1 * operand.value]
if index.op == "+":
operand = index.expr
if isinstance(operand, IntExpr):
return [operand.value]
typ = get_proper_type(self.accept(index))
if isinstance(typ, Instance) and typ.last_known_value is not None:
typ = typ.last_known_value
Expand Down
9 changes: 6 additions & 3 deletions mypy/exprtotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,12 @@ def expr_to_unanalyzed_type(
elif isinstance(expr, UnaryExpr):
typ = expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax)
if isinstance(typ, RawExpressionType):
if isinstance(typ.literal_value, int) and expr.op == "-":
typ.literal_value *= -1
return typ
if isinstance(typ.literal_value, int):
if expr.op == "-":
typ.literal_value *= -1
return typ
elif expr.op == "+":
return typ
raise TypeTranslationError()
elif isinstance(expr, IntExpr):
return RawExpressionType(expr.value, "builtins.int", line=expr.line, column=expr.column)
Expand Down
25 changes: 19 additions & 6 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
return int_pow_callback
elif fullname == "builtins.int.__neg__":
return int_neg_callback
elif fullname == "builtins.int.__pos__":
return int_pos_callback
elif fullname in ("builtins.tuple.__mul__", "builtins.tuple.__rmul__"):
return tuple_mul_callback
elif fullname in {n + ".setdefault" for n in TPDICT_FB_NAMES}:
Expand Down Expand Up @@ -471,32 +473,43 @@ def int_pow_callback(ctx: MethodContext) -> Type:
return ctx.default_return_type


def int_neg_callback(ctx: MethodContext) -> Type:
"""Infer a more precise return type for int.__neg__.
def int_neg_callback(ctx: MethodContext, multiplier: int = -1) -> Type:
"""Infer a more precise return type for int.__neg__ and int.__pos__.

This is mainly used to infer the return type as LiteralType
if the original underlying object is a LiteralType object
if the original underlying object is a LiteralType object.
"""
if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None:
value = ctx.type.last_known_value.value
fallback = ctx.type.last_known_value.fallback
if isinstance(value, int):
if is_literal_type_like(ctx.api.type_context[-1]):
return LiteralType(value=-value, fallback=fallback)
return LiteralType(value=multiplier * value, fallback=fallback)
else:
return ctx.type.copy_modified(
last_known_value=LiteralType(
value=-value, fallback=ctx.type, line=ctx.type.line, column=ctx.type.column
value=multiplier * value,
fallback=ctx.type,
line=ctx.type.line,
column=ctx.type.column,
)
)
elif isinstance(ctx.type, LiteralType):
value = ctx.type.value
fallback = ctx.type.fallback
if isinstance(value, int):
return LiteralType(value=-value, fallback=fallback)
return LiteralType(value=multiplier * value, fallback=fallback)
return ctx.default_return_type


def int_pos_callback(ctx: MethodContext) -> Type:
"""Infer a more precise return type for int.__pos__.

This is identical to __neg__, except the value is not inverted.
"""
return int_neg_callback(ctx, +1)


def tuple_mul_callback(ctx: MethodContext) -> Type:
"""Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.

Expand Down
13 changes: 12 additions & 1 deletion test-data/unit/check-literal.test
Original file line number Diff line number Diff line change
Expand Up @@ -397,29 +397,36 @@ from typing_extensions import Literal
a1: Literal[4]
b1: Literal[0x2a]
c1: Literal[-300]
d1: Literal[+8]

reveal_type(a1) # N: Revealed type is "Literal[4]"
reveal_type(b1) # N: Revealed type is "Literal[42]"
reveal_type(c1) # N: Revealed type is "Literal[-300]"
reveal_type(d1) # N: Revealed type is "Literal[8]"

a2t = Literal[4]
b2t = Literal[0x2a]
c2t = Literal[-300]
d2t = Literal[+8]
a2: a2t
b2: b2t
c2: c2t
d2: d2t

reveal_type(a2) # N: Revealed type is "Literal[4]"
reveal_type(b2) # N: Revealed type is "Literal[42]"
reveal_type(c2) # N: Revealed type is "Literal[-300]"
reveal_type(d2) # N: Revealed type is "Literal[8]"

def f1(x: Literal[4]) -> Literal[4]: pass
def f2(x: Literal[0x2a]) -> Literal[0x2a]: pass
def f3(x: Literal[-300]) -> Literal[-300]: pass
def f4(x: Literal[+8]) -> Literal[+8]: pass

reveal_type(f1) # N: Revealed type is "def (x: Literal[4]) -> Literal[4]"
reveal_type(f2) # N: Revealed type is "def (x: Literal[42]) -> Literal[42]"
reveal_type(f3) # N: Revealed type is "def (x: Literal[-300]) -> Literal[-300]"
reveal_type(f4) # N: Revealed type is "def (x: Literal[8]) -> Literal[8]"
[builtins fixtures/tuple.pyi]
[out]

Expand Down Expand Up @@ -2747,6 +2754,9 @@ d: Literal[1] = 1
e: Literal[2] = 2
f: Literal[+1] = 1
g: Literal[+2] = 2
h: Literal[1] = +1
i: Literal[+2] = 2
j: Literal[+3] = +3

x: Literal[+True] = True # E: Invalid type: Literal[...] cannot contain arbitrary expressions
y: Literal[-True] = -1 # E: Invalid type: Literal[...] cannot contain arbitrary expressions
Expand All @@ -2759,14 +2769,15 @@ from typing_extensions import Literal, Final

ONE: Final = 1
x: Literal[-1] = -ONE
y: Literal[+1] = +ONE

TWO: Final = 2
THREE: Final = 3

err_code = -TWO
if bool():
err_code = -THREE
[builtins fixtures/float.pyi]
[builtins fixtures/ops.pyi]

[case testAliasForEnumTypeAsLiteral]
from typing_extensions import Literal
Expand Down
4 changes: 3 additions & 1 deletion test-data/unit/check-tuples.test
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,12 @@ if int():
b = t1[-1]
if int():
a = t1[(0)]
if int():
b = t1[+1]
if int():
x = t3[0:3] # type (A, B, C)
if int():
y = t3[0:5:2] # type (A, C, E)
y = t3[0:+5:2] # type (A, C, E)
if int():
x = t3[:-2] # type (A, B, C)

Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/tuple.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class classmethod: pass
# We need int and slice for indexing tuples.
class int:
def __neg__(self) -> 'int': pass
def __pos__(self) -> 'int': pass
class float: pass
class slice: pass
class bool(int): pass
Expand Down
Loading