Skip to content

Commit

Permalink
core: Add attribute variables to the declarative assembly format. (#2028
Browse files Browse the repository at this point in the history
)

Add attribute variables to the declarative assembly format.
This does not support properties yet, this will be added in a following
PR.

The attributes that are parsed by the declarative assembly format are
not
parsed/printed back in the attr-dict directive.
  • Loading branch information
math-fehr authored Jan 26, 2024
1 parent 74b44e5 commit e743222
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 16 deletions.
44 changes: 44 additions & 0 deletions tests/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ParameterDef,
VarOperand,
VarOpResult,
attr_def,
irdl_attr_definition,
irdl_op_definition,
operand_def,
Expand Down Expand Up @@ -245,6 +246,49 @@ class PropOp(IRDLOperation):
check_equivalence(program, generic_program, ctx)


################################################################################
# Attribute variables #
################################################################################


@irdl_op_definition
class OpWithAttr(IRDLOperation):
name = "test.one_attr"

attr = attr_def(Attribute)
assembly_format = "$attr attr-dict"


@pytest.mark.parametrize(
"program, generic_program",
[
("test.one_attr i32", '"test.one_attr"() {"attr" = i32} : () -> ()'),
(
'test.one_attr i32 {"attr2" = i64}',
'"test.one_attr"() {"attr" = i32, "attr2" = i64} : () -> ()',
),
],
)
def test_standard_attr_directive(program: str, generic_program: str):
ctx = MLContext()
ctx.load_op(OpWithAttr)

check_equivalence(program, generic_program, ctx)
check_roundtrip(program, ctx)


def test_attr_variable_shadowed():
ctx = MLContext()
ctx.load_op(OpWithAttr)

parser = Parser(ctx, "test.one_attr i32 {attr = i64}")
with pytest.raises(
ParseError,
match="attributes attr are defined in other parts",
):
parser.parse_operation()


################################################################################
# Punctuations, keywords, and whitespaces #
################################################################################
Expand Down
60 changes: 47 additions & 13 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,9 @@ class ParsingState:
constraint_variables: dict[str, Attribute]

def __init__(self, op_def: OpDef):
has_attributes = bool(op_def.attributes) and list(op_def.attributes.keys()) != [
"operandSegmentSizes"
]
if has_attributes or op_def.regions or op_def.successors:
if op_def.regions or op_def.successors:
raise NotImplementedError(
"Operation definitions with attributes, regions, "
"Operation definitions with regions "
"or successors are not yet supported"
)
self.operands = [None] * len(op_def.operands)
Expand Down Expand Up @@ -265,6 +262,13 @@ class AttrDictDirective(FormatDirective):
with_keyword: bool
"""If this is set, the format starts with the `attributes` keyword."""

reserved_attr_names: set[str]
"""
The set of attributes that should not be printed.
These attributes are printed in other places in the format, and thus would be
printed twice otherwise.
"""

def parse(self, parser: Parser, state: ParsingState) -> None:
if self.with_keyword:
res = parser.parse_optional_attr_dict_with_keyword()
Expand All @@ -274,20 +278,27 @@ def parse(self, parser: Parser, state: ParsingState) -> None:
res = res.data
else:
res = parser.parse_optional_attr_dict()
state.attributes = res
defined_reserved_keys = self.reserved_attr_names & res.keys()
if defined_reserved_keys:
parser.raise_error(
f"attributes {', '.join(defined_reserved_keys)} are defined in other parts of the "
"assembly format, and thus should not be defined in the attribute "
"dictionary."
)
state.attributes |= res

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if not op.attributes and not op.properties:
return
if self.with_keyword:
printer.print(" attributes")
if any(name in op.attributes for name in op.properties):
raise ValueError(
"Cannot print attributes and properties with the same name"
"in a signle dictionary"
)
printer.print_op_attributes(
op.attributes | op.properties, reserved_attr_names=["operandSegmentSizes"]
op.attributes | op.properties,
reserved_attr_names=self.reserved_attr_names,
print_keyword=self.with_keyword,
)
state.last_was_punctuation = False
state.should_emit_space = False
Expand All @@ -297,7 +308,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No
class OperandVariable(FormatDirective):
"""
An operand variable, with the following format:
operand-directive ::= percent-ident
operand-directive ::= dollar-ident
The directive will request a space to be printed after.
"""

Expand Down Expand Up @@ -344,7 +355,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No
class OperandTypeDirective(FormatDirective):
"""
An operand variable type directive, with the following format:
operand-type-directive ::= type(percent-ident)
operand-type-directive ::= type(dollar-ident)
The directive will request a space to be printed right after.
"""

Expand Down Expand Up @@ -393,7 +404,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No
class ResultVariable(FormatDirective):
"""
An result variable, with the following format:
result-directive ::= percent-ident
result-directive ::= dollar-ident
This directive can not be used for parsing and printing directly, as result
parsing is not handled by the custom operation parser.
"""
Expand Down Expand Up @@ -442,7 +453,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No
class ResultTypeDirective(FormatDirective):
"""
A result variable type directive, with the following format:
result-type-directive ::= type(percent-ident)
result-type-directive ::= type(dollar-ident)
The directive will request a space to be printed right after.
"""

Expand All @@ -463,6 +474,29 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No
state.should_emit_space = True


@dataclass(frozen=True)
class AttributeVariable(FormatDirective):
"""
An attribute variable, with the following format:
result-directive ::= dollar-ident
The directive will request a space to be printed right after.
"""

attr_name: str
"""The attribute name as it should be in the attribute or property dictionary."""

def parse(self, parser: Parser, state: ParsingState) -> None:
attribute = parser.parse_attribute()
state.attributes[self.attr_name] = attribute

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
state.should_emit_space = True
state.last_was_punctuation = False
printer.print_attribute(op.attributes[self.attr_name])


@dataclass(frozen=True)
class VariadicResultTypeDirective(ResultTypeDirective):
"""
Expand Down
44 changes: 41 additions & 3 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from enum import Enum, auto

from xdsl.ir import Attribute
from xdsl.irdl import OpDef, VariadicDef
from xdsl.irdl import AttrSizedOperandSegments, OpDef, VariadicDef
from xdsl.irdl.declarative_assembly_format import (
AttrDictDirective,
AttributeVariable,
FormatDirective,
FormatProgram,
KeywordDirective,
Expand Down Expand Up @@ -89,6 +90,8 @@ class FormatParser(BaseParser):
"""The operand types that are already parsed."""
seen_result_types: list[bool]
"""The result types that are already parsed."""
seen_attributes: set[str]
"""The attributes that are already parsed."""
has_attr_dict: bool = field(default=False)
"""True if the attribute dictionary has already been parsed."""
context: ParsingContext = field(default=ParsingContext.TopLevel)
Expand All @@ -105,6 +108,7 @@ def __init__(self, input: str, op_def: OpDef):
self.seen_operands = [False] * len(op_def.operands)
self.seen_operand_types = [False] * len(op_def.operands)
self.seen_result_types = [False] * len(op_def.results)
self.seen_attributes = set[str]()
self.type_resolutions = {}

def parse_format(self) -> FormatProgram:
Expand Down Expand Up @@ -133,12 +137,27 @@ def parse_format(self) -> FormatProgram:
"A variadic directive cannot be followed by a comma literal."
)

self.add_reserved_attrs_to_directive(elements)
seen_variables = self.resolve_types()
self.verify_attr_dict()
self.verify_operands(seen_variables)
self.verify_results(seen_variables)
return FormatProgram(elements)

def add_reserved_attrs_to_directive(self, elements: list[FormatDirective]):
"""
Add reserved attributes to the attr-dict directive.
These are the attributes that are printed/parsed in other places in the format,
and thus should not be printed in the attr-dict directive.
"""
for idx, element in enumerate(elements):
if isinstance(element, AttrDictDirective):
elements[idx] = AttrDictDirective(
with_keyword=element.with_keyword,
reserved_attr_names=self.seen_attributes,
)
return

def resolve_types(self) -> set[str]:
"""
Find out which constraint variables can be inferred from the parsed attributes.
Expand Down Expand Up @@ -203,7 +222,9 @@ def verify_attr_dict(self):
if not self.has_attr_dict:
self.raise_error("'attr-dict' directive not found")

def parse_optional_variable(self) -> OperandVariable | ResultVariable | None:
def parse_optional_variable(
self,
) -> OperandVariable | ResultVariable | AttributeVariable | None:
"""
Parse a variable, if present, with the following format:
variable ::= `$` bare-ident
Expand All @@ -223,6 +244,8 @@ def parse_optional_variable(self) -> OperandVariable | ResultVariable | None:
if self.seen_operands[idx]:
self.raise_error(f"operand '{variable_name}' is already bound")
self.seen_operands[idx] = True
if isinstance(operand_def, VariadicDef):
self.seen_attributes.add(AttrSizedOperandSegments.attribute_name)
if isinstance(operand_def, VariadicDef):
return VariadicOperandVariable(variable_name, idx)
else:
Expand All @@ -242,6 +265,17 @@ def parse_optional_variable(self) -> OperandVariable | ResultVariable | None:
else:
return ResultVariable(variable_name, idx)

# Check if the variable is an attribute
if variable_name in self.op_def.accessor_names:
(attr_name, attr_or_prop) = self.op_def.accessor_names[variable_name]
if attr_or_prop == "property":
self.raise_error("properties are currently not supported")
if self.context == ParsingContext.TopLevel:
if attr_name in self.seen_attributes:
self.raise_error(f"attribute '{variable_name}' is already bound")
self.seen_attributes.add(attr_name)
return AttributeVariable(variable_name)

self.raise_error(
"expected variable to refer to an operand, "
"attribute, region, result, or successor"
Expand Down Expand Up @@ -283,6 +317,8 @@ def parse_type_directive(self) -> FormatDirective:
self.raise_error(f"type of '{name}' is already bound")
self.seen_result_types[index] = True
res = ResultTypeDirective(name, index)
case AttributeVariable():
self.raise_error("can only take the type of an operand or result")

self.parse_punctuation(")")
self.context = previous_context
Expand Down Expand Up @@ -358,4 +394,6 @@ def create_attr_dict_directive(self, with_keyword: bool) -> AttrDictDirective:
"in the assembly format description"
)
self.has_attr_dict = True
return AttrDictDirective(with_keyword=with_keyword)
# reserved_attr_names is populated once the format is parsed, as some attributes
# might appear after the attr-dict directive
return AttrDictDirective(with_keyword=with_keyword, reserved_attr_names=set())

0 comments on commit e743222

Please sign in to comment.