Skip to content

Commit

Permalink
core: Make IRDL operations and attributes final (#2159)
Browse files Browse the repository at this point in the history
It was previously possible to inherit an operation or attribute.
While this may be useful, it is also pretty dangerous, as now a
constraint
such as `operand_def(MyAttribute)` will also accept other attributes
such as
`MyInheritedAttribute`, which likely shouldn't be accepted.

For instance, it was previously possible to use `riscv.fastmathflags` in
the LLVM
dialect, as the attribute was inheriting `llvm.fastmathflags`.

Moreover, it is necessary for the declarative format to know that a
specific attribute
type is expected, and this isn't possible if the attribute might be
inherited.

The solution I'm proposing here is to use a new `final` decorator inside
`irdl_attr_definition`
and `irdl_op_definition` to avoid these inheritances. Note that this
`final` decorator will
trigger an error if anyone tries to subclass the final class.

Resolves: #2064
  • Loading branch information
math-fehr authored Feb 20, 2024
1 parent d5cc687 commit 137aee9
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 40 deletions.
28 changes: 28 additions & 0 deletions tests/utils/test_final.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest

from xdsl.utils.runtime_final import is_runtime_final, runtime_final


class NotFinal:
"""A non-Final class."""


@runtime_final
class Final:
"""A Final class."""


def test_is_runtime_final():
"""Check that `is_runtime_final` returns the correct value."""
assert not is_runtime_final(NotFinal)
assert is_runtime_final(Final)


def test_final_inheritance_error():
"""Check that final classes cannot be subclassed."""
with pytest.raises(TypeError, match="Subclassing final classes is restricted"):

class SubFinal(Final):
pass

SubFinal()
15 changes: 11 additions & 4 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Annotated, Generic, TypeVar, cast, overload
from typing import Annotated, Generic, Literal, TypeVar, cast, overload

from xdsl.dialects.builtin import (
AnyFloat,
Expand All @@ -20,13 +20,14 @@
UnrankedTensorType,
VectorType,
)
from xdsl.dialects.llvm import FastMathAttr as LLVMFastMathAttr
from xdsl.dialects.llvm import FastMathAttrBase, FastMathFlag
from xdsl.ir import Attribute, Dialect, Operation, OpResult, SSAValue
from xdsl.irdl import (
AnyOf,
ConstraintVar,
IRDLOperation,
Operand,
irdl_attr_definition,
irdl_op_definition,
operand_def,
opt_prop_def,
Expand Down Expand Up @@ -80,13 +81,19 @@
]


class FastMathFlagsAttr(LLVMFastMathAttr):
@irdl_attr_definition
class FastMathFlagsAttr(FastMathAttrBase):
"""
arith.fastmath is a mirror of LLVMs fastmath flags.
"""

name = "arith.fastmath"

def __init__(self, flags: None | Sequence[FastMathFlag] | Literal["none", "fast"]):
# irdl_attr_definition defines an __init__ if none is defined, so we need to
# explicitely define one here.
super().__init__(flags)


@irdl_op_definition
class Constant(IRDLOperation):
Expand Down
3 changes: 2 additions & 1 deletion xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,7 +1358,6 @@ def create(
return UnregisteredOpWithName


@irdl_attr_definition
class UnregisteredAttr(ParametrizedAttribute, ABC):
"""
An unregistered attribute or type.
Expand Down Expand Up @@ -1409,13 +1408,15 @@ def with_name_and_type(cls, name: str, is_type: bool) -> type[UnregisteredAttr]:
`MLContext` to get an `UnregisteredAttr` type.
"""

@irdl_attr_definition
class UnregisteredAttrWithName(UnregisteredAttr):
def verify(self):
if self.attr_name.data != name:
raise VerifyException("Unregistered attribute name mismatch")
if self.is_type.data != int(is_type):
raise VerifyException("Unregistered attribute is_type mismatch")

@irdl_attr_definition
class UnregisteredAttrTypeWithName(UnregisteredAttr, TypeAttribute):
def verify(self):
if self.attr_name.data != name:
Expand Down
15 changes: 13 additions & 2 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from abc import ABC
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from types import EllipsisType
from typing import Annotated, Generic, Literal, TypeVar
Expand Down Expand Up @@ -1118,8 +1119,8 @@ def try_parse(parser: AttrParser) -> set[FastMathFlag] | None:
return None


@irdl_attr_definition
class FastMathAttr(Data[tuple[FastMathFlag, ...]]):
@dataclass(frozen=True)
class FastMathAttrBase(Data[tuple[FastMathFlag, ...]]):
name = "llvm.fastmath"

@property
Expand Down Expand Up @@ -1170,6 +1171,16 @@ def print_parameter(self, printer: Printer):
)


@irdl_attr_definition
class FastMathAttr(FastMathAttrBase):
name = "llvm.fastmath"

def __init__(self, flags: None | Sequence[FastMathFlag] | Literal["none", "fast"]):
# irdl_attr_definition defines an __init__ if none is defined, so we need to
# explicitely define one here.
super().__init__(flags)


@irdl_op_definition
class CallIntrinsicOp(IRDLOperation):
"""
Expand Down
12 changes: 9 additions & 3 deletions xdsl/dialects/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence, Set
from io import StringIO
from typing import IO, Annotated, ClassVar, Generic, TypeAlias, TypeVar
from typing import IO, Annotated, ClassVar, Generic, Literal, TypeAlias, TypeVar

from typing_extensions import Self

Expand All @@ -18,7 +18,7 @@
UnitAttr,
i32,
)
from xdsl.dialects.llvm import FastMathAttr as LLVMFastMathAttr
from xdsl.dialects.llvm import FastMathAttrBase, FastMathFlag
from xdsl.ir import (
Attribute,
Block,
Expand Down Expand Up @@ -60,13 +60,19 @@
from xdsl.utils.hints import isa


class FastMathFlagsAttr(LLVMFastMathAttr):
@irdl_attr_definition
class FastMathFlagsAttr(FastMathAttrBase):
"""
riscv.fastmath is a mirror of LLVMs fastmath flags.
"""

name = "riscv.fastmath"

def __init__(self, flags: None | Sequence[FastMathFlag] | Literal["none", "fast"]):
# irdl_attr_definition defines an __init__ if none is defined, so we need to
# explicitely define one here.
super().__init__(flags)


class RISCVRegisterType(Data[str], TypeAttribute, ABC):
"""
Expand Down
74 changes: 44 additions & 30 deletions xdsl/irdl/irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
get_type_var_from_generic_class,
get_type_var_mapping,
)
from xdsl.utils.runtime_final import runtime_final

if TYPE_CHECKING:
from xdsl.parser import Parser
Expand Down Expand Up @@ -202,7 +203,7 @@ def verify(self, attr: Attribute, constraint_vars: dict[str, Attribute]) -> None


def attr_constr_coercion(
attr: (Attribute | type[Attribute] | AttrConstraint),
attr: Attribute | type[Attribute] | AttrConstraint,
) -> AttrConstraint:
"""
Attributes are coerced into EqAttrConstraints,
Expand Down Expand Up @@ -533,20 +534,24 @@ class IRDLOperation(Operation):
def __init__(
self: IRDLOperation,
*,
operands: Sequence[SSAValue | Operation | Sequence[SSAValue | Operation] | None]
| None = None,
operands: (
Sequence[SSAValue | Operation | Sequence[SSAValue | Operation] | None]
| None
) = None,
result_types: Sequence[Attribute | Sequence[Attribute] | None] | None = None,
properties: Mapping[str, Attribute | None] | None = None,
attributes: Mapping[str, Attribute | None] | None = None,
successors: Sequence[Block | Sequence[Block] | None] | None = None,
regions: Sequence[
Region
regions: (
Sequence[
Region
| None
| Sequence[Operation]
| Sequence[Block]
| Sequence[Region | Sequence[Operation] | Sequence[Block]]
]
| None
| Sequence[Operation]
| Sequence[Block]
| Sequence[Region | Sequence[Operation] | Sequence[Block]]
]
| None = None,
) = None,
):
if operands is None:
operands = []
Expand Down Expand Up @@ -575,20 +580,24 @@ def __init__(
def build(
cls: type[IRDLOperationInvT],
*,
operands: Sequence[SSAValue | Operation | Sequence[SSAValue | Operation] | None]
| None = None,
operands: (
Sequence[SSAValue | Operation | Sequence[SSAValue | Operation] | None]
| None
) = None,
result_types: Sequence[Attribute | Sequence[Attribute] | None] | None = None,
attributes: Mapping[str, Attribute | None] | None = None,
properties: Mapping[str, Attribute | None] | None = None,
successors: Sequence[Block | Sequence[Block] | None] | None = None,
regions: Sequence[
Region
regions: (
Sequence[
Region
| None
| Sequence[Operation]
| Sequence[Block]
| Sequence[Region | Sequence[Operation] | Sequence[Block]]
]
| None
| Sequence[Operation]
| Sequence[Block]
| Sequence[Region | Sequence[Operation] | Sequence[Block]]
]
| None = None,
) = None,
) -> IRDLOperationInvT:
"""Create a new operation using builders."""
op = cls.__new__(cls)
Expand Down Expand Up @@ -1342,10 +1351,9 @@ def wrong_field_exception(field_name: str) -> PyRDLOpDefinitionError:

# Get attribute constraints from a list of pyrdl constraints
def get_constraint(
pyrdl_constr: AttrConstraint
| Attribute
| type[Attribute]
| TypeVar,
pyrdl_constr: (
AttrConstraint | Attribute | type[Attribute] | TypeVar
),
) -> AttrConstraint:
return _irdl_list_to_attr_constraint(
(pyrdl_constr,),
Expand Down Expand Up @@ -2331,8 +2339,12 @@ def get_irdl_definition(cls: type[_PAttrT]):

new_fields["get_irdl_definition"] = get_irdl_definition

return dataclass(frozen=True, init=False)(
type.__new__(type(cls), cls.__name__, (cls,), {**cls.__dict__, **new_fields})
return runtime_final(
dataclass(frozen=True, init=False)(
type.__new__(
type(cls), cls.__name__, (cls,), {**cls.__dict__, **new_fields}
)
)
)


Expand All @@ -2343,11 +2355,13 @@ def irdl_attr_definition(cls: TypeAttributeInvT) -> TypeAttributeInvT:
if issubclass(cls, ParametrizedAttribute):
return irdl_param_attr_definition(cls)
if issubclass(cls, Data):
return dataclass(frozen=True)( # pyright: ignore[reportGeneralTypeIssues]
type(
cls.__name__,
(cls,), # pyright: ignore[reportUnknownArgumentType]
dict(cls.__dict__),
return runtime_final(
dataclass(frozen=True)( # pyright: ignore[reportGeneralTypeIssues]
type(
cls.__name__,
(cls,), # pyright: ignore[reportUnknownArgumentType]
dict(cls.__dict__),
)
)
)
raise TypeError(
Expand Down
26 changes: 26 additions & 0 deletions xdsl/utils/runtime_final.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Any, TypeVar


def _init_subclass(cls: type, *args: Any, **kwargs: Any) -> None:
"""Is used by `final` to prevent a class from being subclassed at runtime."""
raise TypeError("Subclassing final classes is restricted")


C = TypeVar("C", bound=type)


def runtime_final(cls: C) -> C:
"""Prevent a class from being subclassed at runtime."""

# It is safe to discard the previous __init_subclass__ method as anyway
# the new one will raise an error.
setattr(cls, "__init_subclass__", classmethod(_init_subclass))

# This is a marker to check if a class is final or not.
setattr(cls, "__final__", True)
return cls


def is_runtime_final(cls: type) -> bool:
"""Check if a class is final."""
return hasattr(cls, "__final__")

0 comments on commit 137aee9

Please sign in to comment.