Skip to content

Commit

Permalink
core: Fix bug in extractor when having nested variables (#3850)
Browse files Browse the repository at this point in the history
When having an attribute constraint variable `T` and a variable `U` that
could infer the
variable `T` (for instance `U = Vector<T>`), the constraint `U` would
not realize that it can infer `T`.
  • Loading branch information
math-fehr authored Feb 6, 2025
1 parent e9ab63b commit f1e8d25
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
38 changes: 38 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2399,6 +2399,44 @@ class TwoOperandsNestedVarOp(IRDLOperation):
check_roundtrip(program, ctx)


def test_nested_inference_variable():
"""Check that Param<T> infers correctly T when Param<T> is nested in a variable."""

@irdl_attr_definition
class ParamOne(ParametrizedAttribute, TypeAttribute, Generic[_T]):
name = "test.param_one"

p: ParameterDef[_T]

@classmethod
def constr(
cls, *, p: GenericAttrConstraint[_T] | None = None
) -> ParamAttrConstraint[ParamOne[_T]]:
return ParamAttrConstraint[ParamOne[_T]](ParamOne, (p,))

@irdl_op_definition
class ResultTypeIsOperandParamOp(IRDLOperation):
T: ClassVar = VarConstraint("T", AnyAttr())
U: ClassVar = VarConstraint("U", ParamOne.constr(p=T))

name = "test.result_type_is_operand_param"
res = result_def(T)
arg = operand_def(U)

assembly_format = "$arg attr-dict `:` type($arg)"

ctx = MLContext()
ctx.load_op(ResultTypeIsOperandParamOp)
ctx.load_attr(ParamOne)
ctx.load_dialect(Test)
program = textwrap.dedent(
"""\
%0 = "test.op"() : () -> !test.param_one<i32>
%1 = test.result_type_is_operand_param %0 : !test.param_one<i32>"""
)
check_roundtrip(program, ctx)


def test_non_verifying_inference():
"""
Check that non-verifying operands/results will
Expand Down
4 changes: 3 additions & 1 deletion xdsl/irdl/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ def verify(
constraint_context.set_variable(self.name, attr)

def get_variable_extractors(self) -> dict[str, VarExtractor[AttributeCovT]]:
return {self.name: IdExtractor()}
return merge_extractor_dicts(
{self.name: IdExtractor()}, self.constraint.get_variable_extractors()
)

def infer(self, context: InferenceContext) -> AttributeCovT:
v = context.variables[self.name]
Expand Down

0 comments on commit f1e8d25

Please sign in to comment.