Skip to content

Commit

Permalink
core: Allow to parse string enums (#2696)
Browse files Browse the repository at this point in the history
Add support in the parser to parse an `StrEnum`.
It parses a bare identifier, and check if it is part of the `StrEnum`.
  • Loading branch information
math-fehr authored Jun 11, 2024
1 parent 43a625a commit 9ca5a64
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 13 deletions.
4 changes: 2 additions & 2 deletions tests/filecheck/dialects/gpu/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"builtin.module"() ({
}) {"wrong_all_reduce_operation" = #gpu<all_reduce_op magic>}: () -> ()

// CHECK: Expected `add`, `and`, `max`, `min`, `mul`, `or` or `xor`.
// CHECK: Expected `add`, `and`, `max`, `min`, `mul`, `or`, or `xor`.

// -----

Expand Down Expand Up @@ -77,7 +77,7 @@
"builtin.module"() ({
}) {"wrong_dim" = #gpu<dim w>}: () -> ()

// CHECK: Expected `x`, `y` or `z`.
// CHECK: Expected `x`, `y`, or `z`.

// -----

Expand Down
37 changes: 37 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from xdsl.printer import Printer
from xdsl.utils.exceptions import ParseError, VerifyException
from xdsl.utils.lexer import Token
from xdsl.utils.str_enum import StrEnum

# pyright: reportPrivateUsage=false

Expand Down Expand Up @@ -882,3 +883,39 @@ def test_parse_visibility(keyword: str, expected: StringAttr | None):
parser.parse_visibility_keyword()
else:
assert parser.parse_visibility_keyword() == expected


class MyEnum(StrEnum):
A = "a"
B = "b"
C = "c"


@pytest.mark.parametrize(
"keyword, expected",
[
("a", MyEnum.A),
("b", MyEnum.B),
("c", MyEnum.C),
("cc", None),
],
)
def test_parse_str_enum(keyword: str, expected: MyEnum | None):
assert Parser(MLContext(), keyword).parse_optional_str_enum(MyEnum) == expected

parser = Parser(MLContext(), keyword)
if expected is None:
with pytest.raises(ParseError, match="Expected `a`, `b`, or `c`"):
parser.parse_str_enum(MyEnum)
else:
assert parser.parse_str_enum(MyEnum) == expected


class MySingletonEnum(StrEnum):
A = "a"


def test_parse_singleton_enum_fail():
parser = Parser(MLContext(), "b")
with pytest.raises(ParseError, match="Expected `a`"):
parser.parse_str_enum(MySingletonEnum)
12 changes: 1 addition & 11 deletions xdsl/ir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,17 +608,7 @@ def print_parameter(self, printer: Printer) -> None:

@classmethod
def parse_parameter(cls, parser: AttrParser) -> EnumType:
enum_type = cls.enum_type

val = parser.parse_identifier()
if val not in enum_type.__members__.values():
enum_values = list(enum_type)
if len(enum_values) == 1:
parser.raise_error(f"Expected `{enum_values[0]}`.")
parser.raise_error(
f"Expected `{'`, `'.join(enum_values[:-1])}` or `{enum_values[-1]}`."
)
return cast(EnumType, enum_type(val))
return cast(EnumType, parser.parse_str_enum(cls.enum_type))


@dataclass(frozen=True, init=False)
Expand Down
27 changes: 27 additions & 0 deletions xdsl/parser/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from xdsl.utils.exceptions import ParseError
from xdsl.utils.lexer import Lexer, Position, Span, StringLiteral, Token
from xdsl.utils.str_enum import StrEnum


@dataclass(init=False)
Expand All @@ -34,6 +35,7 @@ def __init__(self, lexer: Lexer, dialect_stack: list[str] | None = None):


_AnyInvT = TypeVar("_AnyInvT")
_EnumType = TypeVar("_EnumType", bound=StrEnum)


@dataclass
Expand Down Expand Up @@ -534,3 +536,28 @@ def parse_punctuation(
kind = Token.Kind.get_punctuation_kind_from_spelling(punctuation)
self._parse_token(kind, f"Expected '{punctuation}'" + context_msg)
return punctuation

def parse_str_enum(self, enum_type: type[_EnumType]) -> _EnumType:
"""Parse a string enum value."""
result = self.parse_optional_str_enum(enum_type)
if result is not None:
return result
enum_values = tuple(enum_type)
if len(enum_values) == 1:
self.raise_error(f"Expected `{enum_values[0]}`.")
self.raise_error(
f"Expected `{'`, `'.join(enum_values[:-1])}`, or `{enum_values[-1]}`."
)

def parse_optional_str_enum(self, enum_type: type[_EnumType]) -> _EnumType | None:
"""Parse a string enum value, if present."""

if self._current_token.kind != Token.Kind.BARE_IDENT:
return None

val = self._current_token.text
if val not in enum_type.__members__.values():
return None

self._consume_token(Token.Kind.BARE_IDENT)
return enum_type(val)

0 comments on commit 9ca5a64

Please sign in to comment.