Skip to content

Commit

Permalink
core: properties in declarative format (#2031)
Browse files Browse the repository at this point in the history
This PR enables properties in the declarative format.

By default, all properties should be specified in the format.
However, with the `ParsePropInAttrDict` IRDL option, properties are
instead parsed by the attribute dictionary, to keep compatibility with
MLIR.
  • Loading branch information
math-fehr authored Jan 30, 2024
1 parent 6e0975f commit 1ee55fd
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 25 deletions.
43 changes: 43 additions & 0 deletions tests/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
EqAttrConstraint,
IRDLOperation,
ParameterDef,
ParsePropInAttrDict,
VarOperand,
VarOpResult,
attr_def,
irdl_attr_definition,
irdl_op_definition,
operand_def,
opt_prop_def,
prop_def,
result_def,
var_operand_def,
var_result_def,
Expand Down Expand Up @@ -237,6 +239,7 @@ def test_attr_dict_prop_fallack(program: str, generic_program: str):
class PropOp(IRDLOperation):
name = "test.prop"
prop = opt_prop_def(Attribute)
irdl_options = [ParsePropInAttrDict()]
assembly_format = "attr-dict"

ctx = MLContext()
Expand Down Expand Up @@ -289,6 +292,46 @@ def test_attr_variable_shadowed():
parser.parse_operation()


def test_missing_property_error():
class OpWithMissingProp(IRDLOperation):
name = "test.missing_prop"

prop1 = prop_def(Attribute)
prop2 = prop_def(Attribute)
assembly_format = "$prop1 attr-dict"

with pytest.raises(
PyRDLOpDefinitionError,
match="prop2 properties are missing",
):
irdl_op_definition(OpWithMissingProp)


@pytest.mark.parametrize(
"program, generic_program",
[
("test.one_prop i32", '"test.one_prop"() <{"prop" = i32}> : () -> ()'),
(
'test.one_prop i32 {"attr2" = i64}',
'"test.one_prop"() <{"prop" = i32}> {"attr2" = i64} : () -> ()',
),
],
)
def test_standard_prop_directive(program: str, generic_program: str):
@irdl_op_definition
class OpWithProp(IRDLOperation):
name = "test.one_prop"

prop = prop_def(Attribute)
assembly_format = "$prop attr-dict"

ctx = MLContext()
ctx.load_op(OpWithProp)

check_equivalence(program, generic_program, ctx)
check_roundtrip(program, ctx)


################################################################################
# Punctuations, keywords, and whitespaces #
################################################################################
Expand Down
3 changes: 3 additions & 0 deletions xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
ConstraintVar,
IRDLOperation,
Operand,
ParsePropInAttrDict,
VarOperand,
irdl_op_definition,
operand_def,
Expand Down Expand Up @@ -71,6 +72,7 @@ class Load(IRDLOperation):
indices: VarOperand = var_operand_def(IndexType())
res: OpResult = result_def(T)

irdl_options = [ParsePropInAttrDict()]
assembly_format = "$memref `[` $indices `]` attr-dict `:` type($memref)"

# TODO varargs for indexing, which must match the memref dimensions
Expand Down Expand Up @@ -109,6 +111,7 @@ class Store(IRDLOperation):
memref: Operand = operand_def(MemRefType[T])
indices: VarOperand = var_operand_def(IndexType())

irdl_options = [ParsePropInAttrDict()]
assembly_format = "$value `,` $memref `[` $indices `]` attr-dict `:` type($memref)"

def verify_(self):
Expand Down
62 changes: 48 additions & 14 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,13 @@ def parse(
"Single operand has no type or variadic/optional type"
operands.append(parser.resolve_operand(uo, ot))

properties = op_def.split_properties(state.attributes)
# Get the properties from the attribute dictionary if no properties are
# defined. This is necessary to be compatible with MLIR format, such as
# `memref.load`.
if state.properties:
properties = state.properties
else:
properties = op_def.split_properties(state.attributes)
return op_type.build(
result_types=result_types,
operands=operands,
Expand Down Expand Up @@ -269,6 +275,12 @@ class AttrDictDirective(FormatDirective):
printed twice otherwise.
"""

print_properties: bool
"""
If this is set, also print properties as part of the attribute dictionary.
This is used to keep compatibility with MLIR which allows that.
"""

def parse(self, parser: Parser, state: ParsingState) -> None:
if self.with_keyword:
res = parser.parse_optional_attr_dict_with_keyword()
Expand All @@ -288,18 +300,32 @@ def parse(self, parser: Parser, state: ParsingState) -> None:
state.attributes |= res

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if not op.attributes and not op.properties:
return
if any(name in op.attributes for name in op.properties):
raise ValueError(
"Cannot print attributes and properties with the same name"
"in a signle dictionary"
if self.print_properties:
if (
not (set(op.attributes.keys()) | set(op.properties.keys()))
- self.reserved_attr_names
):
return
if any(name in op.attributes for name in op.properties):
raise ValueError(
"Cannot print attributes and properties with the same name "
"in a signle dictionary"
)
printer.print_op_attributes(
op.attributes | op.properties,
reserved_attr_names=self.reserved_attr_names,
print_keyword=self.with_keyword,
)
printer.print_op_attributes(
op.attributes | op.properties,
reserved_attr_names=self.reserved_attr_names,
print_keyword=self.with_keyword,
)
else:
if not set(op.attributes.keys()) - self.reserved_attr_names:
return
printer.print_op_attributes(
op.attributes,
reserved_attr_names=self.reserved_attr_names,
print_keyword=self.with_keyword,
)

# This is changed only if something was printed
state.last_was_punctuation = False
state.should_emit_space = False

Expand Down Expand Up @@ -484,17 +510,25 @@ class AttributeVariable(FormatDirective):

attr_name: str
"""The attribute name as it should be in the attribute or property dictionary."""
is_property: bool
"""Should this attribute be put in the attribute or property dictionary."""

def parse(self, parser: Parser, state: ParsingState) -> None:
attribute = parser.parse_attribute()
state.attributes[self.attr_name] = attribute
if self.is_property:
state.properties[self.attr_name] = attribute
else:
state.attributes[self.attr_name] = attribute

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
state.should_emit_space = True
state.last_was_punctuation = False
printer.print_attribute(op.attributes[self.attr_name])
if self.is_property:
printer.print_attribute(op.properties[self.attr_name])
else:
printer.print_attribute(op.attributes[self.attr_name])


@dataclass(frozen=True)
Expand Down
58 changes: 50 additions & 8 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from enum import Enum, auto

from xdsl.ir import Attribute
from xdsl.irdl import AttrSizedOperandSegments, OpDef, VariadicDef
from xdsl.irdl import AttrSizedOperandSegments, OpDef, ParsePropInAttrDict, VariadicDef
from xdsl.irdl.declarative_assembly_format import (
AttrDictDirective,
AttributeVariable,
Expand Down Expand Up @@ -92,6 +92,8 @@ class FormatParser(BaseParser):
"""The result types that are already parsed."""
seen_attributes: set[str]
"""The attributes that are already parsed."""
seen_properties: set[str]
"""The properties that are already parsed."""
has_attr_dict: bool = field(default=False)
"""True if the attribute dictionary has already been parsed."""
context: ParsingContext = field(default=ParsingContext.TopLevel)
Expand All @@ -109,6 +111,7 @@ def __init__(self, input: str, op_def: OpDef):
self.seen_operand_types = [False] * len(op_def.operands)
self.seen_result_types = [False] * len(op_def.results)
self.seen_attributes = set[str]()
self.seen_properties = set[str]()
self.type_resolutions = {}

def parse_format(self) -> FormatProgram:
Expand Down Expand Up @@ -140,6 +143,7 @@ def parse_format(self) -> FormatProgram:
self.add_reserved_attrs_to_directive(elements)
seen_variables = self.resolve_types()
self.verify_attr_dict()
self.verify_properties()
self.verify_operands(seen_variables)
self.verify_results(seen_variables)
return FormatProgram(elements)
Expand All @@ -155,6 +159,7 @@ def add_reserved_attrs_to_directive(self, elements: list[FormatDirective]):
elements[idx] = AttrDictDirective(
with_keyword=element.with_keyword,
reserved_attr_names=self.seen_attributes,
print_properties=element.print_properties,
)
return

Expand Down Expand Up @@ -222,6 +227,30 @@ def verify_attr_dict(self):
if not self.has_attr_dict:
self.raise_error("'attr-dict' directive not found")

def verify_properties(self):
"""
Check that all properties are present, unless `ParsePropInAttrDict` option is
used.
"""
# This is used for compatibility with MLIR
if any(
isinstance(option, ParsePropInAttrDict) for option in self.op_def.options
):
if self.seen_properties:
self.raise_error(
"properties cannot be specified in the declarative format "
"when 'ParsePropInAttrDict' IRDL option is used. They are instead "
"parsed from the attribute dictionary."
)
return
missing_properties = set(self.op_def.properties.keys()) - self.seen_properties
if missing_properties:
self.raise_error(
f"{', '.join(missing_properties)} properties are missing from "
"the declarative format. If this is intentional, consider using "
"'ParsePropInAttrDict' IRDL option."
)

def parse_optional_variable(
self,
) -> OperandVariable | ResultVariable | AttributeVariable | None:
Expand Down Expand Up @@ -268,13 +297,19 @@ def parse_optional_variable(
# Check if the variable is an attribute
if variable_name in self.op_def.accessor_names:
(attr_name, attr_or_prop) = self.op_def.accessor_names[variable_name]
if attr_or_prop == "property":
self.raise_error("properties are currently not supported")
if self.context == ParsingContext.TopLevel:
if attr_name in self.seen_attributes:
self.raise_error(f"attribute '{variable_name}' is already bound")
self.seen_attributes.add(attr_name)
return AttributeVariable(variable_name)
if attr_or_prop == "property":
if attr_name in self.seen_properties:
self.raise_error(f"property '{variable_name}' is already bound")
self.seen_properties.add(attr_name)
else:
if attr_name in self.seen_attributes:
self.raise_error(
f"attribute '{variable_name}' is already bound"
)
self.seen_attributes.add(attr_name)

return AttributeVariable(variable_name, attr_or_prop == "property")

self.raise_error(
"expected variable to refer to an operand, "
Expand Down Expand Up @@ -394,6 +429,13 @@ def create_attr_dict_directive(self, with_keyword: bool) -> AttrDictDirective:
"in the assembly format description"
)
self.has_attr_dict = True
print_properties = any(
isinstance(option, ParsePropInAttrDict) for option in self.op_def.options
)
# reserved_attr_names is populated once the format is parsed, as some attributes
# might appear after the attr-dict directive
return AttrDictDirective(with_keyword=with_keyword, reserved_attr_names=set())
return AttrDictDirective(
with_keyword=with_keyword,
reserved_attr_names=set(),
print_properties=print_properties,
)
14 changes: 11 additions & 3 deletions xdsl/irdl/irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,16 @@ class AttrSizedSuccessorSegments(AttrSizedSegments):
"""Name of the attribute containing the variadic successor sizes."""


@dataclass
class ParsePropInAttrDict(IRDLOption):
"""
Parse properties in the attribute dictionary instead of requiring them to
be in the assembly format.
This should only be used to ensure MLIR compatibility, it is otherwise
bad design to use it.
"""


@dataclass
class OperandOrResultDef(ABC):
"""An operand or a result definition. Should not be used directly."""
Expand Down Expand Up @@ -1987,9 +1997,7 @@ def irdl_op_init(
f"Unexpected option {option} in operation definition {op_def}."
)
case _:
raise ValueError(
f"Unexpected option {option} in operation definition {op_def}."
)
pass

Operation.__init__(
self,
Expand Down

0 comments on commit 1ee55fd

Please sign in to comment.