-
Notifications
You must be signed in to change notification settings - Fork 12.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][DLTI] Enable types as keys in DLTI-query utils #105995
Conversation
@llvm/pr-subscribers-mlir-dlti @llvm/pr-subscribers-mlir Author: Rolf Morel (rolfmorel) ChangesEnable support for query functions - include transform.dlti.query - to take types as keys. As the data layout specific attributes already supported types as keys, this change enables querying such attributes in the expected way. Full diff: https://github.com/llvm/llvm-project/pull/105995.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTI.h b/mlir/include/mlir/Dialect/DLTI/DLTI.h
index a97eb523cb0631..f268fea340a6fb 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTI.h
+++ b/mlir/include/mlir/Dialect/DLTI/DLTI.h
@@ -26,7 +26,7 @@ namespace mlir {
namespace dlti {
/// Perform a DLTI-query at `op`, recursively querying each key of `keys` on
/// query interface-implementing attrs, starting from attr obtained from `op`.
-FailureOr<Attribute> query(Operation *op, ArrayRef<StringAttr> keys,
+FailureOr<Attribute> query(Operation *op, ArrayRef<DataLayoutEntryKey> keys,
bool emitError = false);
} // namespace dlti
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td b/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
index 1b1bebfaab4e38..f25bb383912d45 100644
--- a/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
+++ b/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
@@ -26,9 +26,10 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
A lookup is performed for the given `keys` at `target` op - or its closest
interface-implementing ancestor - by way of the `DLTIQueryInterface`, which
- returns an attribute for a key. If more than one key is provided, the lookup
- continues recursively, now on the returned attributes, with the condition
- that these implement the above interface. For example if the payload IR is
+ returns an attribute for a key. Each key should be either a (quoted) string
+ or a type. If more than one key is provided, the lookup continues
+ recursively, now on the returned attributes, with the condition that these
+ implement the above interface. For example if the payload IR is
```
module attributes {#dlti.map = #dlti.map<#dlti.dl_entry<"A",
@@ -52,7 +53,7 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- StrArrayAttr:$keys);
+ ArrayAttr:$keys);
let results = (outs TransformParamTypeInterface:$associated_attr);
let assemblyFormat =
"$keys `at` $target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 7f8e11a1b73341..58f8799b0714d7 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -424,8 +424,8 @@ getClosestQueryable(Operation *op) {
return std::pair(queryable, op);
}
-FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
- bool emitError) {
+FailureOr<Attribute>
+dlti::query(Operation *op, ArrayRef<DataLayoutEntryKey> keys, bool emitError) {
auto [queryable, queryOp] = getClosestQueryable(op);
Operation *reportOp = (queryOp ? queryOp : op);
@@ -438,6 +438,17 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
return failure();
}
+ auto keyToStr = [](DataLayoutEntryKey key) -> std::string {
+ if (auto strKey = llvm::dyn_cast<StringAttr>(key))
+ return "\"" + std::string(strKey.getValue()) + "\"";
+ if (auto typeKey = llvm::dyn_cast<Type>(key)) {
+ std::string buf;
+ llvm::raw_string_ostream(buf) << typeKey;
+ return buf;
+ }
+ llvm_unreachable("DataLayoutEntryKey was not `StringAttr` or `Type`");
+ };
+
Attribute currentAttr = queryable;
for (auto &&[idx, key] : llvm::enumerate(keys)) {
if (auto map = llvm::dyn_cast<DLTIQueryInterface>(currentAttr)) {
@@ -446,17 +457,24 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
if (emitError) {
auto diag = op->emitError() << "target op of failed DLTI query";
diag.attachNote(reportOp->getLoc())
- << "key " << key << " has no DLTI-mapping per attr: " << map;
+ << "key " << keyToStr(key)
+ << " has no DLTI-mapping per attr: " << map;
}
return failure();
}
currentAttr = *maybeAttr;
} else {
if (emitError) {
+ std::string commaSeparatedKeys;
+ llvm::interleave(
+ keys.take_front(idx), // All prior keys.
+ [&](auto key) { commaSeparatedKeys += keyToStr(key); },
+ [&]() { commaSeparatedKeys += ","; });
+
auto diag = op->emitError() << "target op of failed DLTI query";
diag.attachNote(reportOp->getLoc())
<< "got non-DLTI-queryable attribute upon looking up keys ["
- << keys.take_front(idx) << "] at op";
+ << commaSeparatedKeys << "] at op";
}
return failure();
}
diff --git a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
index 90aef82bddff00..2f171a8375b46d 100644
--- a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
+++ b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
@@ -33,7 +33,16 @@ void transform::QueryOp::getEffects(
DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results, TransformState &state) {
- auto keys = SmallVector<StringAttr>(getKeys().getAsRange<StringAttr>());
+ auto keys = SmallVector<DataLayoutEntryKey>();
+ for (Attribute key : getKeys()) {
+ if (auto strKey = dyn_cast<StringAttr>(key))
+ keys.push_back(strKey);
+ else if (auto typeKey = dyn_cast<TypeAttr>(key))
+ keys.push_back(typeKey.getValue());
+ else
+ return emitDefiniteFailure("'transform.dlti.query' keys of wrong type: "
+ "only StringAttr and TypeAttr are allowed");
+ }
FailureOr<Attribute> result = dlti::query(target, keys, /*emitError=*/true);
diff --git a/mlir/test/Dialect/DLTI/invalid.mlir b/mlir/test/Dialect/DLTI/invalid.mlir
index 05f919fa256713..4b04f0195ef823 100644
--- a/mlir/test/Dialect/DLTI/invalid.mlir
+++ b/mlir/test/Dialect/DLTI/invalid.mlir
@@ -33,6 +33,14 @@
// -----
+// expected-error@below {{repeated layout entry key: 'i32'}}
+"test.unknown_op"() { test.unknown_attr = #dlti.map<
+ #dlti.dl_entry<i32, 42>,
+ #dlti.dl_entry<i32, 42>
+>} : () -> ()
+
+// -----
+
// expected-error@below {{repeated layout entry key: 'i32'}}
"test.unknown_op"() { test.unknown_attr = #dlti.dl_spec<
#dlti.dl_entry<i32, 42>,
diff --git a/mlir/test/Dialect/DLTI/query.mlir b/mlir/test/Dialect/DLTI/query.mlir
index 10e91afd2ca7e1..e449c2c44bc617 100644
--- a/mlir/test/Dialect/DLTI/query.mlir
+++ b/mlir/test/Dialect/DLTI/query.mlir
@@ -17,6 +17,60 @@ module attributes {transform.with_named_sequence} {
// -----
+// expected-remark @below {{associated attr 42 : i32}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, 42 : i32>>} {
+ func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+ %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
+ %param = transform.dlti.query [i32] at %module : (!transform.any_op) -> !transform.any_param
+ transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// expected-remark @below {{associated attr 32 : i32}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, #dlti.map<#dlti.dl_entry<"width_in_bits", 32 : i32>>>>} {
+ func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+ %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
+ %param = transform.dlti.query [i32,"width_in_bits"] at %module : (!transform.any_op) -> !transform.any_param
+ transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// expected-remark @below {{width in bits of i32 = 32 : i64}}
+// expected-remark @below {{width in bits of f64 = 64 : i64}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>, #dlti.dl_entry<f64, 64>>>>} {
+ func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+ %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
+ %i32bits = transform.dlti.query ["width_in_bits",i32] at %module : (!transform.any_op) -> !transform.any_param
+ %f64bits = transform.dlti.query ["width_in_bits",f64] at %module : (!transform.any_op) -> !transform.any_param
+ transform.debug.emit_param_as_remark %i32bits, "width in bits of i32 =" at %module : !transform.any_param, !transform.any_op
+ transform.debug.emit_param_as_remark %f64bits, "width in bits of f64 =" at %module : !transform.any_param, !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// expected-remark @below {{associated attr 42 : i32}}
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
func.func private @f()
@@ -336,6 +390,23 @@ module attributes {transform.with_named_sequence} {
// -----
+// expected-note @below {{got non-DLTI-queryable attribute upon looking up keys [i32]}}
+module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<i32, 32 : i32>>} {
+ // expected-error @below {{target op of failed DLTI query}}
+ func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{'transform.dlti.query' op failed to apply}}
+ %param = transform.dlti.query [i32,"width_in_bits"] at %func : (!transform.any_op) -> !transform.any_param
+ transform.yield
+ }
+}
+
+// -----
+
module {
// expected-error @below {{target op of failed DLTI query}}
// expected-note @below {{no DLTI-queryable attrs on target op or any of its ancestors}}
@@ -353,6 +424,23 @@ module attributes {transform.with_named_sequence} {
// -----
+// expected-note @below {{key i64 has no DLTI-mapping per attr: #dlti.map<#dlti.dl_entry<i32, 32 : i64>>}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>>>>} {
+ // expected-error @below {{target op of failed DLTI query}}
+ func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{'transform.dlti.query' op failed to apply}}
+ %param = transform.dlti.query ["width_in_bits",i64] at %func : (!transform.any_op) -> !transform.any_param
+ transform.yield
+ }
+}
+
+// -----
+
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
func.func private @f()
}
diff --git a/mlir/test/Dialect/DLTI/valid.mlir b/mlir/test/Dialect/DLTI/valid.mlir
index 4133eac5424ce8..023caf6ac5a05f 100644
--- a/mlir/test/Dialect/DLTI/valid.mlir
+++ b/mlir/test/Dialect/DLTI/valid.mlir
@@ -206,3 +206,18 @@ module attributes {
"GPU": #dlti.target_device_spec<
#dlti.dl_entry<"L1_cache_size_in_bytes", "128">>
>} {}
+
+
+// -----
+
+// CHECK: "test.op_with_dlti_map"() ({
+// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42 : i64>>}
+"test.op_with_dlti_map"() ({
+}) { dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42>> } : () -> ()
+
+// -----
+
+// CHECK: "test.op_with_dlti_map"() ({
+// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<i32, 42 : i64>>}
+"test.op_with_dlti_map"() ({
+}) { dlti.map = #dlti.map<#dlti.dl_entry<i32, 42>> } : () -> ()
\ No newline at end of file
|
Hi @rengolin, @ftynse, @banach-space, @joker-eph, @adam-smnk & @Dinistro, as I think you are interested in making progress on making use of DLTI's target descriptors, I thought to ping you on this. If any of you could help with review, that would be appreciated - thanks! |
a97ad19
to
b087fdc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the only use-case a list of types for some parent property?
Is this so much better than just having a string "f32" and then parsing it into a type?
Honest questions, I can't follow what you're trying to reach, here.
To me there are two main arguments for wanting this:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks fine, +1 for completeness
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am concerned about the amount of string manipulation involved here.
StringAttr and Type are uniqued pointers in the context and we should be able to take advantage of this here for simple pointer-based keys here.
I think there's a misunderstanding here, @joker-eph : only the erroring-out path has any string manipulation, and only when requested, i.e. all string manipulation is guarded by The key comparison itself is delegated to the
== of PointerUnion<StringAttr,Type> , this does make use of (void * ) pointer comparisons for equality checks. Note as well that currently all DLTI attributes have essentially this same query() implementation.
If you are concerned about the amount of string manipulation on the error-reporting path, I am happy to discuss that. If there's consensus in favour of less string manipulation even if that means less informative error messages, this is something I could work with on. |
Enable support for query functions - including transform.dlti.query - to take types as keys. As the data layout specific attributes already supported types as keys, this change enables querying such attributes in the expected way.
5758973
to
e1899e3
Compare
Thanks for the review, @rengolin, @adam-smnk and @joker-eph ! I just now did a rebase, a squash and a check-all. I think this PR is good to go. If somebody could help with merging, that would be appreciated! |
Enable support for query functions - including transform.dlti.query - to take types as keys. As the data layout specific attributes already supported types as keys, this change enables querying such attributes in the expected way.