diff --git a/tests/filecheck/dialects/gpu/invalid.mlir b/tests/filecheck/dialects/gpu/invalid.mlir index a7cfd767c3..6df9ab849c 100644 --- a/tests/filecheck/dialects/gpu/invalid.mlir +++ b/tests/filecheck/dialects/gpu/invalid.mlir @@ -19,7 +19,7 @@ "builtin.module"() ({ }) {"wrong_all_reduce_operation" = #gpu}: () -> () -// CHECK: Expected `add`, `and`, `max`, `min`, `mul`, `or` or `xor`. +// CHECK: Expected `add`, `and`, `max`, `min`, `mul`, `or`, or `xor`. // ----- @@ -77,7 +77,7 @@ "builtin.module"() ({ }) {"wrong_dim" = #gpu}: () -> () -// CHECK: Expected `x`, `y` or `z`. +// CHECK: Expected `x`, `y`, or `z`. // ----- diff --git a/tests/test_parser.py b/tests/test_parser.py index 2930556bca..f993f66dc0 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -28,6 +28,7 @@ from xdsl.printer import Printer from xdsl.utils.exceptions import ParseError, VerifyException from xdsl.utils.lexer import Token +from xdsl.utils.str_enum import StrEnum # pyright: reportPrivateUsage=false @@ -882,3 +883,39 @@ def test_parse_visibility(keyword: str, expected: StringAttr | None): parser.parse_visibility_keyword() else: assert parser.parse_visibility_keyword() == expected + + +class MyEnum(StrEnum): + A = "a" + B = "b" + C = "c" + + +@pytest.mark.parametrize( + "keyword, expected", + [ + ("a", MyEnum.A), + ("b", MyEnum.B), + ("c", MyEnum.C), + ("cc", None), + ], +) +def test_parse_str_enum(keyword: str, expected: MyEnum | None): + assert Parser(MLContext(), keyword).parse_optional_str_enum(MyEnum) == expected + + parser = Parser(MLContext(), keyword) + if expected is None: + with pytest.raises(ParseError, match="Expected `a`, `b`, or `c`"): + parser.parse_str_enum(MyEnum) + else: + assert parser.parse_str_enum(MyEnum) == expected + + +class MySingletonEnum(StrEnum): + A = "a" + + +def test_parse_singleton_enum_fail(): + parser = Parser(MLContext(), "b") + with pytest.raises(ParseError, match="Expected `a`"): + parser.parse_str_enum(MySingletonEnum) diff --git a/xdsl/ir/core.py b/xdsl/ir/core.py index cd2a7b2449..5bc8e95529 100644 --- a/xdsl/ir/core.py +++ b/xdsl/ir/core.py @@ -608,17 +608,7 @@ def print_parameter(self, printer: Printer) -> None: @classmethod def parse_parameter(cls, parser: AttrParser) -> EnumType: - enum_type = cls.enum_type - - val = parser.parse_identifier() - if val not in enum_type.__members__.values(): - enum_values = list(enum_type) - if len(enum_values) == 1: - parser.raise_error(f"Expected `{enum_values[0]}`.") - parser.raise_error( - f"Expected `{'`, `'.join(enum_values[:-1])}` or `{enum_values[-1]}`." - ) - return cast(EnumType, enum_type(val)) + return cast(EnumType, parser.parse_str_enum(cls.enum_type)) @dataclass(frozen=True, init=False) diff --git a/xdsl/parser/base_parser.py b/xdsl/parser/base_parser.py index f34b4c355e..aa702db9fe 100644 --- a/xdsl/parser/base_parser.py +++ b/xdsl/parser/base_parser.py @@ -11,6 +11,7 @@ from xdsl.utils.exceptions import ParseError from xdsl.utils.lexer import Lexer, Position, Span, StringLiteral, Token +from xdsl.utils.str_enum import StrEnum @dataclass(init=False) @@ -34,6 +35,7 @@ def __init__(self, lexer: Lexer, dialect_stack: list[str] | None = None): _AnyInvT = TypeVar("_AnyInvT") +_EnumType = TypeVar("_EnumType", bound=StrEnum) @dataclass @@ -534,3 +536,28 @@ def parse_punctuation( kind = Token.Kind.get_punctuation_kind_from_spelling(punctuation) self._parse_token(kind, f"Expected '{punctuation}'" + context_msg) return punctuation + + def parse_str_enum(self, enum_type: type[_EnumType]) -> _EnumType: + """Parse a string enum value.""" + result = self.parse_optional_str_enum(enum_type) + if result is not None: + return result + enum_values = tuple(enum_type) + if len(enum_values) == 1: + self.raise_error(f"Expected `{enum_values[0]}`.") + self.raise_error( + f"Expected `{'`, `'.join(enum_values[:-1])}`, or `{enum_values[-1]}`." + ) + + def parse_optional_str_enum(self, enum_type: type[_EnumType]) -> _EnumType | None: + """Parse a string enum value, if present.""" + + if self._current_token.kind != Token.Kind.BARE_IDENT: + return None + + val = self._current_token.text + if val not in enum_type.__members__.values(): + return None + + self._consume_token(Token.Kind.BARE_IDENT) + return enum_type(val)