Skip to content

Commit

Permalink
core: Add support for nested inferrence in IRDL (#2005)
Browse files Browse the repository at this point in the history
This PR allows us to infer the types of some operands and results from
the types of other operands and results.

It does so by adding three functions in `AttrConstraint`:
* `get_resolved_variables`: This function returns the set of
`ConstraintVar` that will be set during verification.
* `can_infer`: Returns `True` if we can infer a unique attribute
represented by the `AttrConstraint` from the assignment of
`ConstraintVar`.
* `infer`: Returns the only attribute that satisfy this constraint,
given the constraint variables.

We use these methods to get all `ConstraintVar` from the parsed
operand/result types present in the syntax, then infer the constraints
on the types that are not in the syntax.

---------

Co-authored-by: Emilien Bauer <bauer.emilien@gmail.com>
Co-authored-by: Sasha Lopoukhine <superlopuh@gmail.com>
  • Loading branch information
3 people authored Jan 25, 2024
1 parent 53ea859 commit efe08f9
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 134 deletions.
220 changes: 184 additions & 36 deletions tests/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,35 @@

import textwrap
from io import StringIO
from typing import Annotated
from typing import Annotated, Generic, TypeVar

import pytest

from xdsl.dialects.builtin import ModuleOp
from xdsl.dialects.test import Test
from xdsl.ir import Attribute, MLContext, Operation
from xdsl.ir import (
Attribute,
MLContext,
Operation,
ParametrizedAttribute,
TypeAttribute,
)
from xdsl.irdl import (
AllOf,
AnyAttr,
ConstraintVar,
EqAttrConstraint,
IRDLOperation,
ParameterDef,
irdl_attr_definition,
irdl_op_definition,
operand_def,
opt_prop_def,
result_def,
)
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.utils.exceptions import PyRDLOpDefinitionError
from xdsl.utils.exceptions import ParseError, PyRDLOpDefinitionError

################################################################################
# Utils for this test file #
Expand Down Expand Up @@ -321,7 +332,7 @@ class NoOperandTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
def test_operands_missing_type():
"""Test that operands should have their type parsed"""
with pytest.raises(
PyRDLOpDefinitionError, match="type of operand 'operand' not found"
PyRDLOpDefinitionError, match="type of operand 'operand' cannot be inferred"
):

@irdl_op_definition
Expand Down Expand Up @@ -421,45 +432,16 @@ class TwoOperandsOp(IRDLOperation):
check_roundtrip(program, ctx)


@pytest.mark.parametrize(
"format",
[
"$lhs $rhs attr-dict `:` type($lhs)",
"$lhs $rhs attr-dict `:` type($rhs)",
"$lhs $rhs attr-dict `:` type($res)",
],
)
def test_vasic_inference(format: str):
@irdl_op_definition
class TwoOperandsOneResultWithVarOp(IRDLOperation):
T = Annotated[Attribute, ConstraintVar("T")]

name = "test.two_operands_one_result_with_var"
res = result_def(T)
lhs = operand_def(T)
rhs = operand_def(T)

assembly_format = format

ctx = MLContext()
ctx.load_op(TwoOperandsOneResultWithVarOp)
ctx.load_dialect(Test)
program = textwrap.dedent(
"""\
%0, %1 = "test.op"() : () -> (i32, i32)
%2 = test.two_operands_one_result_with_var %0 %1 : i32"""
)
check_roundtrip(program, ctx)


################################################################################
# Results #
################################################################################


def test_missing_result_type():
"""Test that results should have their type parsed."""
with pytest.raises(PyRDLOpDefinitionError, match="result 'result' not found"):
with pytest.raises(
PyRDLOpDefinitionError, match="result 'result' cannot be inferred"
):

@irdl_op_definition
class NoResultTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
Expand Down Expand Up @@ -522,3 +504,169 @@ class TwoResultOp(IRDLOperation):

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


################################################################################
# Inference #
################################################################################

_T = TypeVar("_T", bound=Attribute)


@pytest.mark.parametrize(
"format",
[
"$lhs $rhs attr-dict `:` type($lhs)",
"$lhs $rhs attr-dict `:` type($rhs)",
"$lhs $rhs attr-dict `:` type($res)",
],
)
def test_basic_inference(format: str):
"""Check that we can infer the type of an operand when ConstraintVar are used."""

@irdl_op_definition
class TwoOperandsOneResultWithVarOp(IRDLOperation):
T = Annotated[Attribute, ConstraintVar("T")]

name = "test.two_operands_one_result_with_var"
res = result_def(T)
lhs = operand_def(T)
rhs = operand_def(T)

assembly_format = format

ctx = MLContext()
ctx.load_op(TwoOperandsOneResultWithVarOp)
ctx.load_dialect(Test)
program = textwrap.dedent(
"""\
%0, %1 = "test.op"() : () -> (i32, i32)
%2 = test.two_operands_one_result_with_var %0 %1 : i32
"test.op"(%2) : (i32) -> ()"""
)
check_roundtrip(program, ctx)


def test_eq_attr_inference():
"""Check that operands/results with a fixed type can be inferred."""

@irdl_attr_definition
class UnitType(ParametrizedAttribute, TypeAttribute):
name = "test.unit"

@irdl_op_definition
class OneOperandEqType(IRDLOperation):
name = "test.one_operand_eq_type"
index = operand_def(UnitType())
res = result_def(UnitType())

assembly_format = "attr-dict $index"

ctx = MLContext()
ctx.load_attr(UnitType)
ctx.load_op(OneOperandEqType)
ctx.load_dialect(Test)
program = textwrap.dedent(
"""\
%0 = "test.op"() : () -> !test.unit
%1 = test.one_operand_eq_type %0
"test.op"(%1) : (!test.unit) -> ()"""
)
check_roundtrip(program, ctx)


def test_all_of_attr_inference():
"""Check that AllOf still allows for inference."""

@irdl_attr_definition
class UnitType(ParametrizedAttribute, TypeAttribute):
name = "test.unit"

@irdl_op_definition
class OneOperandEqTypeAllOfNested(IRDLOperation):
name = "test.one_operand_eq_type_all_of_nested"
index = operand_def(AllOf([AnyAttr(), EqAttrConstraint(UnitType())]))

assembly_format = "attr-dict $index"

ctx = MLContext()
ctx.load_attr(UnitType)
ctx.load_op(OneOperandEqTypeAllOfNested)
ctx.load_dialect(Test)
program = textwrap.dedent(
"""\
%0 = "test.op"() : () -> !test.unit
test.one_operand_eq_type_all_of_nested %0"""
)
check_roundtrip(program, ctx)


def test_nested_inference():
"""Check that Param<T> infers correctly T."""

@irdl_attr_definition
class ParamOne(ParametrizedAttribute, TypeAttribute, Generic[_T]):
name = "test.param_one"

n: ParameterDef[Attribute]
p: ParameterDef[_T]
q: ParameterDef[Attribute]

@irdl_op_definition
class TwoOperandsNestedVarOp(IRDLOperation):
T = Annotated[Attribute, ConstraintVar("T")]

name = "test.two_operands_one_result_with_var"
res = result_def(T)
lhs = operand_def(ParamOne[T])
rhs = operand_def(T)

assembly_format = "$lhs $rhs attr-dict `:` type($lhs)"

ctx = MLContext()
ctx.load_op(TwoOperandsNestedVarOp)
ctx.load_attr(ParamOne)
ctx.load_dialect(Test)
program = textwrap.dedent(
"""\
%0, %1 = "test.op"() : () -> (!test.param_one<f16, i32, i1>, i32)
%2 = test.two_operands_one_result_with_var %0 %1 : !test.param_one<f16, i32, i1>"""
)
check_roundtrip(program, ctx)


def test_non_verifying_inference():
"""
Check that non-verifying operands/results will
trigger a ParseError when inference is required.
"""

@irdl_attr_definition
class ParamOne(ParametrizedAttribute, TypeAttribute, Generic[_T]):
name = "test.param_one"
p: ParameterDef[_T]

@irdl_op_definition
class OneOperandOneResultNestedOp(IRDLOperation):
T = Annotated[Attribute, ConstraintVar("T")]

name = "test.one_operand_one_result_nested"
res = result_def(T)
lhs = operand_def(ParamOne[T])

assembly_format = "$lhs attr-dict `:` type($lhs)"

ctx = MLContext()
ctx.load_op(OneOperandOneResultNestedOp)
ctx.load_attr(ParamOne)
ctx.load_dialect(Test)
program = textwrap.dedent(
"""\
%0 = "test.op"() : () -> i32
%1 = test.one_operand_one_result_nested %0 : i32"""
)
with pytest.raises(
ParseError,
match="Verification error while inferring operation type: ",
):
check_roundtrip(program, ctx)
Loading

0 comments on commit efe08f9

Please sign in to comment.