Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Support DialectRegistry extension comparison #101119

Merged
merged 9 commits into from
Aug 5, 2024

Conversation

nikalra
Copy link
Contributor

@nikalra nikalra commented Jul 30, 2024

PassManager::run loads the dependent dialects for each pass into the current context prior to invoking the individual passes. If the dependent dialect is already loaded into the context, this should be a no-op. However, if there are extensions registered in the DialectRegistry, the dependent dialects are unconditionally registered into the context.

This poses a problem for dynamic pass pipelines, however, because they will likely be executing while the context is in an immutable state (because of the parent pass pipeline being run).

To solve this, we'll update the extension registration API on DialectRegistry to require a type ID for each extension that is registered. Then, instead of unconditionally registered dialects into a context if extensions are present, we'll check against the extension type IDs already present in the context's internal DialectRegistry. The context will only be marked as dirty if there are net-new extension types present in the DialectRegistry populated by PassManager::getDependentDialects.

Note: this PR removes the addExtension overload that utilizes std::function as the parameter. This is because std::function is copyable and potentially allocates memory for the contained function so we can't use the function pointer as the unique type ID for the extension.

Downstream changes required:

  • Existing DialectExtension subclasses will need a type ID to be registered for each subclass. More details on how to register a type ID can be found here:
    /// This class provides an efficient unique identifier for a specific C++ type.
  • Existing uses of the std::function overload of addExtension will need to be refactored into dedicated DialectExtension classes with associated type IDs. The attached std::function can either be inlined into or called directly from DialectExtension::apply.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jul 30, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 30, 2024

@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-cf
@llvm/pr-subscribers-openacc
@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-sme
@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir-func
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-mlprogram
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Nikhil Kalra (nikalra)

Changes

The current assert in MLIRContext::appendDialectRegistry effectively prevents passes that leverage complex dialects (i.e. dialects with extensions) from being used in a dynamic pass pipeline. More specifically, registry.isSubsetOf automatically returns false for any dialect registry that has dialect extensions registered, even if the dialects and extensions have already been loaded into the current context.

This change relaxes the assert such that the context dialect registry can be updated inside of a dynamic pass pipeline when in single-threaded mode.


Full diff: https://github.com/llvm/llvm-project/pull/101119.diff

1 Files Affected:

  • (modified) mlir/lib/IR/MLIRContext.cpp (+4-3)
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 12336701c9ca0..14316874e4743 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -410,9 +410,10 @@ void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
   if (registry.isSubsetOf(impl->dialectsRegistry))
     return;
 
-  assert(impl->multiThreadedExecutionContext == 0 &&
-         "appending to the MLIRContext dialect registry while in a "
-         "multi-threaded execution context");
+  assert(!impl->threadingIsEnabled ||
+         impl->multiThreadedExecutionContext == 0 &&
+             "appending to the MLIRContext dialect registry while in a "
+             "multi-threaded execution context");
   registry.appendTo(impl->dialectsRegistry);
 
   // For the already loaded dialects, apply any possible extensions immediately.

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whether threading is enabling or not should not affect the invariants of the pass manager, the assert is very intentional in its current form.

Seems like an issue with detecting that some loading is a no-op?

@nikalra
Copy link
Contributor Author

nikalra commented Aug 1, 2024

Whether threading is enabling or not should not affect the invariants of the pass manager, the assert is very intentional in its current form.

Seems like an issue with detecting that some loading is a no-op?

It seems like this should be a better approach by allowing us to check if extensions have already been registered into the context (instead of conservatively assuming they aren't). But, this change touches 90+ extra files, so if there's a way to do this more cleanly I'm more than happy to investigate that further!

@joker-eph
Copy link
Collaborator

joker-eph commented Aug 1, 2024

Thanks for giving it a try! I feel that relying on a string provided by the user is not desirable. We need to revisit this with the same kind of mechanism we used elsewhere (based on TypeID / global symbol address) somehow.

You probably want to narrow this down to one example that we can iterate on before engaging in large scale changes.

Copy link
Contributor Author

@nikalra nikalra left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hopefully this is closer to what you had in mind!

std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
}
template <typename... DialectsT>
void
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll have to remove the std::function overload of addExtension so we can use the underlying function pointer as the TypeID, but the lost functionality can be achieved with a DialectExtension subclass if needed.

Copy link
Contributor

@River707 River707 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks a lot cleaner! Will this require any changes to existing in-tree usages? I do appreciate not needing to change ~every current usage.

@nikalra
Copy link
Contributor Author

nikalra commented Aug 1, 2024

This looks a lot cleaner! Will this require any changes to existing in-tree usages? I do appreciate not needing to change ~every current usage.

I think just a handful: we should get most of the lambdas for free since I don't think any of them have captures, but we'll have to add type ID macros to the DialectExtension subclasses.

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Thanks :)

@joker-eph
Copy link
Collaborator

Can you add in the PR description a note about the removed API and the advice on how to proceed for people affected (how to replace it).

Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
@nikalra
Copy link
Contributor Author

nikalra commented Aug 1, 2024

Can you add in the PR description a note about the removed API and the advice on how to proceed for people affected (how to replace it).

Done!

@nikalra nikalra requested review from joker-eph and River707 August 5, 2024 19:58
@nikalra
Copy link
Contributor Author

nikalra commented Aug 5, 2024

If this change looks good, would it be possible to merge it in? I don't have merge permissions for LLVM. Thanks!

@joker-eph
Copy link
Collaborator

(Waiting for CI to complete)

@joker-eph
Copy link
Collaborator

Can you add in the PR description a note about the removed API and the advice on how to proceed for people affected (how to replace it)

Actually this seems missing? Or maybe I wasn't clear: downstream folks who will broken by this patch would appreciate in the description a step-by-step list of actions they need to take to upgrade their codebase.

@nikalra
Copy link
Contributor Author

nikalra commented Aug 5, 2024

Can you add in the PR description a note about the removed API and the advice on how to proceed for people affected (how to replace it)

Actually this seems missing? Or maybe I wasn't clear: downstream folks who will broken by this patch would appreciate in the description a step-by-step list of actions they need to take to upgrade their codebase.

Added step-by-step instructions for the two modified use cases: existing DialectExtensions and the removed overload. Please let me know if I'm missing anything there or if there's anything that won't be clear to affected users!

@joker-eph joker-eph merged commit 84cc186 into llvm:main Aug 5, 2024
7 checks passed
@joker-eph
Copy link
Collaborator

Perfect, thanks!

fifield added a commit to Xilinx/mlir-aie that referenced this pull request Aug 13, 2024
@nikalra nikalra deleted the dialect-registration-single-threaded branch February 27, 2025 00:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants