diff --git a/tests/test_declarative_assembly_format.py b/tests/test_declarative_assembly_format.py index b94da4763d..35ff3a3e22 100644 --- a/tests/test_declarative_assembly_format.py +++ b/tests/test_declarative_assembly_format.py @@ -26,6 +26,7 @@ ParameterDef, VarOperand, VarOpResult, + attr_def, irdl_attr_definition, irdl_op_definition, operand_def, @@ -245,6 +246,49 @@ class PropOp(IRDLOperation): check_equivalence(program, generic_program, ctx) +################################################################################ +# Attribute variables # +################################################################################ + + +@irdl_op_definition +class OpWithAttr(IRDLOperation): + name = "test.one_attr" + + attr = attr_def(Attribute) + assembly_format = "$attr attr-dict" + + +@pytest.mark.parametrize( + "program, generic_program", + [ + ("test.one_attr i32", '"test.one_attr"() {"attr" = i32} : () -> ()'), + ( + 'test.one_attr i32 {"attr2" = i64}', + '"test.one_attr"() {"attr" = i32, "attr2" = i64} : () -> ()', + ), + ], +) +def test_standard_attr_directive(program: str, generic_program: str): + ctx = MLContext() + ctx.load_op(OpWithAttr) + + check_equivalence(program, generic_program, ctx) + check_roundtrip(program, ctx) + + +def test_attr_variable_shadowed(): + ctx = MLContext() + ctx.load_op(OpWithAttr) + + parser = Parser(ctx, "test.one_attr i32 {attr = i64}") + with pytest.raises( + ParseError, + match="attributes attr are defined in other parts", + ): + parser.parse_operation() + + ################################################################################ # Punctuations, keywords, and whitespaces # ################################################################################ diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index ea86af0409..1db14eccc6 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -43,12 +43,9 @@ class ParsingState: constraint_variables: dict[str, Attribute] def __init__(self, op_def: OpDef): - has_attributes = bool(op_def.attributes) and list(op_def.attributes.keys()) != [ - "operandSegmentSizes" - ] - if has_attributes or op_def.regions or op_def.successors: + if op_def.regions or op_def.successors: raise NotImplementedError( - "Operation definitions with attributes, regions, " + "Operation definitions with regions " "or successors are not yet supported" ) self.operands = [None] * len(op_def.operands) @@ -265,6 +262,13 @@ class AttrDictDirective(FormatDirective): with_keyword: bool """If this is set, the format starts with the `attributes` keyword.""" + reserved_attr_names: set[str] + """ + The set of attributes that should not be printed. + These attributes are printed in other places in the format, and thus would be + printed twice otherwise. + """ + def parse(self, parser: Parser, state: ParsingState) -> None: if self.with_keyword: res = parser.parse_optional_attr_dict_with_keyword() @@ -274,20 +278,27 @@ def parse(self, parser: Parser, state: ParsingState) -> None: res = res.data else: res = parser.parse_optional_attr_dict() - state.attributes = res + defined_reserved_keys = self.reserved_attr_names & res.keys() + if defined_reserved_keys: + parser.raise_error( + f"attributes {', '.join(defined_reserved_keys)} are defined in other parts of the " + "assembly format, and thus should not be defined in the attribute " + "dictionary." + ) + state.attributes |= res def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if not op.attributes and not op.properties: return - if self.with_keyword: - printer.print(" attributes") if any(name in op.attributes for name in op.properties): raise ValueError( "Cannot print attributes and properties with the same name" "in a signle dictionary" ) printer.print_op_attributes( - op.attributes | op.properties, reserved_attr_names=["operandSegmentSizes"] + op.attributes | op.properties, + reserved_attr_names=self.reserved_attr_names, + print_keyword=self.with_keyword, ) state.last_was_punctuation = False state.should_emit_space = False @@ -297,7 +308,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No class OperandVariable(FormatDirective): """ An operand variable, with the following format: - operand-directive ::= percent-ident + operand-directive ::= dollar-ident The directive will request a space to be printed after. """ @@ -344,7 +355,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No class OperandTypeDirective(FormatDirective): """ An operand variable type directive, with the following format: - operand-type-directive ::= type(percent-ident) + operand-type-directive ::= type(dollar-ident) The directive will request a space to be printed right after. """ @@ -393,7 +404,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No class ResultVariable(FormatDirective): """ An result variable, with the following format: - result-directive ::= percent-ident + result-directive ::= dollar-ident This directive can not be used for parsing and printing directly, as result parsing is not handled by the custom operation parser. """ @@ -442,7 +453,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No class ResultTypeDirective(FormatDirective): """ A result variable type directive, with the following format: - result-type-directive ::= type(percent-ident) + result-type-directive ::= type(dollar-ident) The directive will request a space to be printed right after. """ @@ -463,6 +474,29 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.should_emit_space = True +@dataclass(frozen=True) +class AttributeVariable(FormatDirective): + """ + An attribute variable, with the following format: + result-directive ::= dollar-ident + The directive will request a space to be printed right after. + """ + + attr_name: str + """The attribute name as it should be in the attribute or property dictionary.""" + + def parse(self, parser: Parser, state: ParsingState) -> None: + attribute = parser.parse_attribute() + state.attributes[self.attr_name] = attribute + + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + state.should_emit_space = True + state.last_was_punctuation = False + printer.print_attribute(op.attributes[self.attr_name]) + + @dataclass(frozen=True) class VariadicResultTypeDirective(ResultTypeDirective): """ diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index 7207a9820a..2b311977d6 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -11,9 +11,10 @@ from enum import Enum, auto from xdsl.ir import Attribute -from xdsl.irdl import OpDef, VariadicDef +from xdsl.irdl import AttrSizedOperandSegments, OpDef, VariadicDef from xdsl.irdl.declarative_assembly_format import ( AttrDictDirective, + AttributeVariable, FormatDirective, FormatProgram, KeywordDirective, @@ -89,6 +90,8 @@ class FormatParser(BaseParser): """The operand types that are already parsed.""" seen_result_types: list[bool] """The result types that are already parsed.""" + seen_attributes: set[str] + """The attributes that are already parsed.""" has_attr_dict: bool = field(default=False) """True if the attribute dictionary has already been parsed.""" context: ParsingContext = field(default=ParsingContext.TopLevel) @@ -105,6 +108,7 @@ def __init__(self, input: str, op_def: OpDef): self.seen_operands = [False] * len(op_def.operands) self.seen_operand_types = [False] * len(op_def.operands) self.seen_result_types = [False] * len(op_def.results) + self.seen_attributes = set[str]() self.type_resolutions = {} def parse_format(self) -> FormatProgram: @@ -133,12 +137,27 @@ def parse_format(self) -> FormatProgram: "A variadic directive cannot be followed by a comma literal." ) + self.add_reserved_attrs_to_directive(elements) seen_variables = self.resolve_types() self.verify_attr_dict() self.verify_operands(seen_variables) self.verify_results(seen_variables) return FormatProgram(elements) + def add_reserved_attrs_to_directive(self, elements: list[FormatDirective]): + """ + Add reserved attributes to the attr-dict directive. + These are the attributes that are printed/parsed in other places in the format, + and thus should not be printed in the attr-dict directive. + """ + for idx, element in enumerate(elements): + if isinstance(element, AttrDictDirective): + elements[idx] = AttrDictDirective( + with_keyword=element.with_keyword, + reserved_attr_names=self.seen_attributes, + ) + return + def resolve_types(self) -> set[str]: """ Find out which constraint variables can be inferred from the parsed attributes. @@ -203,7 +222,9 @@ def verify_attr_dict(self): if not self.has_attr_dict: self.raise_error("'attr-dict' directive not found") - def parse_optional_variable(self) -> OperandVariable | ResultVariable | None: + def parse_optional_variable( + self, + ) -> OperandVariable | ResultVariable | AttributeVariable | None: """ Parse a variable, if present, with the following format: variable ::= `$` bare-ident @@ -223,6 +244,8 @@ def parse_optional_variable(self) -> OperandVariable | ResultVariable | None: if self.seen_operands[idx]: self.raise_error(f"operand '{variable_name}' is already bound") self.seen_operands[idx] = True + if isinstance(operand_def, VariadicDef): + self.seen_attributes.add(AttrSizedOperandSegments.attribute_name) if isinstance(operand_def, VariadicDef): return VariadicOperandVariable(variable_name, idx) else: @@ -242,6 +265,17 @@ def parse_optional_variable(self) -> OperandVariable | ResultVariable | None: else: return ResultVariable(variable_name, idx) + # Check if the variable is an attribute + if variable_name in self.op_def.accessor_names: + (attr_name, attr_or_prop) = self.op_def.accessor_names[variable_name] + if attr_or_prop == "property": + self.raise_error("properties are currently not supported") + if self.context == ParsingContext.TopLevel: + if attr_name in self.seen_attributes: + self.raise_error(f"attribute '{variable_name}' is already bound") + self.seen_attributes.add(attr_name) + return AttributeVariable(variable_name) + self.raise_error( "expected variable to refer to an operand, " "attribute, region, result, or successor" @@ -283,6 +317,8 @@ def parse_type_directive(self) -> FormatDirective: self.raise_error(f"type of '{name}' is already bound") self.seen_result_types[index] = True res = ResultTypeDirective(name, index) + case AttributeVariable(): + self.raise_error("can only take the type of an operand or result") self.parse_punctuation(")") self.context = previous_context @@ -358,4 +394,6 @@ def create_attr_dict_directive(self, with_keyword: bool) -> AttrDictDirective: "in the assembly format description" ) self.has_attr_dict = True - return AttrDictDirective(with_keyword=with_keyword) + # reserved_attr_names is populated once the format is parsed, as some attributes + # might appear after the attr-dict directive + return AttrDictDirective(with_keyword=with_keyword, reserved_attr_names=set())