Skip to content

Commit

Permalink
[mypyc] Support __pos__ and __abs__ dunders (#13490)
Browse files Browse the repository at this point in the history
Calls to these dunders on native classes will be specialized to use a
direct method call instead of using PyNumber_Absolute.

Also calls to abs() on any types have been optimized. They no longer
involve a builtins dictionary lookup. It's probably possible to write a
C helper function for abs(int) to avoid the C-API entirely for native
integers, but I don't feel skilled enough to do that yet.
  • Loading branch information
ichard26 authored Sep 1, 2022
1 parent 7ffaf23 commit 3c7e216
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 4 deletions.
9 changes: 7 additions & 2 deletions mypyc/codegen/emitclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,15 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
AS_SEQUENCE_SLOT_DEFS: SlotTable = {"__contains__": ("sq_contains", generate_contains_wrapper)}

AS_NUMBER_SLOT_DEFS: SlotTable = {
# Unary operations.
"__bool__": ("nb_bool", generate_bool_wrapper),
"__neg__": ("nb_negative", generate_dunder_wrapper),
"__invert__": ("nb_invert", generate_dunder_wrapper),
"__int__": ("nb_int", generate_dunder_wrapper),
"__float__": ("nb_float", generate_dunder_wrapper),
"__neg__": ("nb_negative", generate_dunder_wrapper),
"__pos__": ("nb_positive", generate_dunder_wrapper),
"__abs__": ("nb_absolute", generate_dunder_wrapper),
"__invert__": ("nb_invert", generate_dunder_wrapper),
# Binary operations.
"__add__": ("nb_add", generate_bin_op_wrapper),
"__radd__": ("nb_add", generate_bin_op_wrapper),
"__sub__": ("nb_subtract", generate_bin_op_wrapper),
Expand All @@ -97,6 +101,7 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
"__rxor__": ("nb_xor", generate_bin_op_wrapper),
"__matmul__": ("nb_matrix_multiply", generate_bin_op_wrapper),
"__rmatmul__": ("nb_matrix_multiply", generate_bin_op_wrapper),
# In-place binary operations.
"__iadd__": ("nb_inplace_add", generate_dunder_wrapper),
"__isub__": ("nb_inplace_subtract", generate_dunder_wrapper),
"__imul__": ("nb_inplace_multiply", generate_dunder_wrapper),
Expand Down
1 change: 1 addition & 0 deletions mypyc/doc/native_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Functions
* ``cast(<type>, obj)``
* ``type(obj)``
* ``len(obj)``
* ``abs(obj)``
* ``id(obj)``
* ``iter(obj)``
* ``next(iter: Iterator)``
Expand Down
2 changes: 2 additions & 0 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,6 +1486,8 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
if isinstance(typ, RInstance):
if expr_op == "-":
method = "__neg__"
elif expr_op == "+":
method = "__pos__"
elif expr_op == "~":
method = "__invert__"
else:
Expand Down
14 changes: 14 additions & 0 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from mypy.types import AnyType, TypeOfAny
from mypyc.ir.ops import BasicBlock, Integer, RaiseStandardError, Register, Unreachable, Value
from mypyc.ir.rtypes import (
RInstance,
RTuple,
RType,
bool_rprimitive,
Expand Down Expand Up @@ -138,6 +139,19 @@ def translate_globals(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Va
return None


@specialize_function("builtins.abs")
def translate_abs(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
"""Specialize calls on native classes that implement __abs__."""
if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
arg = expr.args[0]
arg_typ = builder.node_type(arg)
if isinstance(arg_typ, RInstance) and arg_typ.class_ir.has_method("__abs__"):
obj = builder.accept(arg)
return builder.gen_method_call(obj, "__abs__", [], None, expr.line)

return None


@specialize_function("builtins.len")
def translate_len(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
Expand Down
10 changes: 10 additions & 0 deletions mypyc/primitives/generic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,16 @@
priority=0,
)

# abs(obj)
function_op(
name="builtins.abs",
arg_types=[object_rprimitive],
return_type=object_rprimitive,
c_function_name="PyNumber_Absolute",
error_kind=ERR_MAGIC,
priority=0,
)

# obj1[obj2]
method_op(
name="__getitem__",
Expand Down
12 changes: 10 additions & 2 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import (
TypeVar, Generic, List, Iterator, Iterable, Dict, Optional, Tuple, Any, Set,
overload, Mapping, Union, Callable, Sequence, FrozenSet
overload, Mapping, Union, Callable, Sequence, FrozenSet, Protocol
)

T = TypeVar('T')
Expand All @@ -12,6 +12,10 @@
K = TypeVar('K') # for keys in mapping
V = TypeVar('V') # for values in mapping

class __SupportsAbs(Protocol[T_co]):
def __abs__(self) -> T_co: pass


class object:
def __init__(self) -> None: pass
def __eq__(self, x: object) -> bool: pass
Expand Down Expand Up @@ -40,6 +44,7 @@ def __truediv__(self, x: float) -> float: pass
def __mod__(self, x: int) -> int: pass
def __neg__(self) -> int: pass
def __pos__(self) -> int: pass
def __abs__(self) -> int: pass
def __invert__(self) -> int: pass
def __and__(self, n: int) -> int: pass
def __or__(self, n: int) -> int: pass
Expand Down Expand Up @@ -88,6 +93,9 @@ def __sub__(self, n: float) -> float: pass
def __mul__(self, n: float) -> float: pass
def __truediv__(self, n: float) -> float: pass
def __neg__(self) -> float: pass
def __pos__(self) -> float: pass
def __abs__(self) -> float: pass
def __invert__(self) -> float: pass

class complex:
def __init__(self, x: object, y: object = None) -> None: pass
Expand Down Expand Up @@ -296,7 +304,7 @@ def zip(x: Iterable[T], y: Iterable[S]) -> Iterator[Tuple[T, S]]: ...
@overload
def zip(x: Iterable[T], y: Iterable[S], z: Iterable[V]) -> Iterator[Tuple[T, S, V]]: ...
def eval(e: str) -> Any: ...
def abs(x: float) -> float: ...
def abs(x: __SupportsAbs[T]) -> T: ...
def exit() -> None: ...
def min(x: T, y: T) -> T: ...
def max(x: T, y: T) -> T: ...
Expand Down
22 changes: 22 additions & 0 deletions mypyc/test-data/irbuild-any.test
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,25 @@ L6:
r4 = unbox(int, r3)
n = r4
return 1

[case testAbsSpecialization]
# Specialization of native classes that implement __abs__ is checked in
# irbuild-dunders.test
def f() -> None:
a = abs(1)
b = abs(1.1)
[out]
def f():
r0, r1 :: object
r2, a :: int
r3, r4, b :: float
L0:
r0 = object 1
r1 = PyNumber_Absolute(r0)
r2 = unbox(int, r1)
a = r2
r3 = 1.1
r4 = PyNumber_Absolute(r3)
b = r4
return 1

19 changes: 19 additions & 0 deletions mypyc/test-data/irbuild-dunders.test
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,19 @@ class C:
def __float__(self) -> float:
return 4.0

def __pos__(self) -> int:
return 5

def __abs__(self) -> int:
return 6

def f(c: C) -> None:
-c
~c
int(c)
float(c)
+c
abs(c)
[out]
def C.__neg__(self):
self :: __main__.C
Expand All @@ -172,16 +180,27 @@ def C.__float__(self):
L0:
r0 = 4.0
return r0
def C.__pos__(self):
self :: __main__.C
L0:
return 10
def C.__abs__(self):
self :: __main__.C
L0:
return 12
def f(c):
c :: __main__.C
r0, r1 :: int
r2, r3, r4, r5 :: object
r6, r7 :: int
L0:
r0 = c.__neg__()
r1 = c.__invert__()
r2 = load_address PyLong_Type
r3 = PyObject_CallFunctionObjArgs(r2, c, 0)
r4 = load_address PyFloat_Type
r5 = PyObject_CallFunctionObjArgs(r4, c, 0)
r6 = c.__pos__()
r7 = c.__abs__()
return 1

11 changes: 11 additions & 0 deletions mypyc/test-data/run-dunders.test
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,22 @@ class C:
def __float__(self) -> float:
return float(self.x + 4)

def __pos__(self) -> int:
return self.x + 5

def __abs__(self) -> int:
return abs(self.x) + 6


def test_unary_dunders_generic() -> None:
a: Any = C(10)

assert -a == 11
assert ~a == 12
assert int(a) == 13
assert float(a) == 14.0
assert +a == 15
assert abs(a) == 16

def test_unary_dunders_native() -> None:
c = C(10)
Expand All @@ -347,6 +356,8 @@ def test_unary_dunders_native() -> None:
assert ~c == 12
assert int(c) == 13
assert float(c) == 14.0
assert +c == 15
assert abs(c) == 16

[case testDundersBinarySimple]
from typing import Any
Expand Down

0 comments on commit 3c7e216

Please sign in to comment.