Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Support DialectRegistry extension comparison #101119

Merged
merged 9 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions mlir/include/mlir/IR/DialectRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
#define MLIR_IR_DIALECTREGISTRY_H

#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

#include <map>
#include <tuple>
Expand Down Expand Up @@ -187,7 +187,8 @@ class DialectRegistry {
nameAndRegistrationIt.second.second);
// Merge the extensions.
for (const auto &extension : extensions)
destination.extensions.push_back(extension->clone());
destination.extensions.emplace_back(extension.first,
extension.second->clone());
}

/// Return the names of dialects known to this registry.
Expand All @@ -206,47 +207,47 @@ class DialectRegistry {
void applyExtensions(MLIRContext *ctx) const;

/// Add the given extension to the registry.
void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
extensions.push_back(std::move(extension));
void addExtension(TypeID extensionID,
std::unique_ptr<DialectExtensionBase> extension) {
extensions.emplace_back(extensionID, std::move(extension));
}

/// Add the given extensions to the registry.
template <typename... ExtensionsT>
void addExtensions() {
(addExtension(std::make_unique<ExtensionsT>()), ...);
(addExtension(TypeID::get<ExtensionsT>(), std::make_unique<ExtensionsT>()),
...);
}

/// Add an extension function that requires the given dialects.
/// Note: This bare functor overload is provided in addition to the
/// std::function variant to enable dialect type deduction, e.g.:
/// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
/// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
/// ... })
///
/// is equivalent to:
/// registry.addExtension<MyDialect>(
/// [](MLIRContext *ctx, MyDialect *dialect){ ... }
/// )
template <typename... DialectsT>
void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
addExtension<DialectsT...>(
std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
}
template <typename... DialectsT>
void
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll have to remove the std::function overload of addExtension so we can use the underlying function pointer as the TypeID, but the lost functionality can be achieved with a DialectExtension subclass if needed.

addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;
using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...);

struct Extension : public DialectExtension<Extension, DialectsT...> {
Extension(const Extension &) = default;
Extension(ExtensionFnT extensionFn)
: extensionFn(std::move(extensionFn)) {}
: DialectExtension<Extension, DialectsT...>(),
extensionFn(extensionFn) {}
~Extension() override = default;

void apply(MLIRContext *context, DialectsT *...dialects) const final {
extensionFn(context, dialects...);
}
ExtensionFnT extensionFn;
};
addExtension(std::make_unique<Extension>(std::move(extensionFn)));
addExtension(TypeID::getFromOpaquePointer(
reinterpret_cast<const void *>(extensionFn)),
std::make_unique<Extension>(extensionFn));
}

/// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
Expand All @@ -255,7 +256,9 @@ class DialectRegistry {

private:
MapTy registry;
std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
using KeyExtensionPair =
std::pair<TypeID, std::unique_ptr<DialectExtensionBase>>;
llvm::SmallVector<KeyExtensionPair> extensions;
};

} // namespace mlir
Expand Down
19 changes: 15 additions & 4 deletions mlir/lib/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -260,7 +261,7 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const {

// Note: Additional extensions may be added while applying an extension.
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
applyExtension(*extensions[i]);
applyExtension(*extensions[i].second);
}

void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
Expand All @@ -287,13 +288,23 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {

// Note: Additional extensions may be added while applying an extension.
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
applyExtension(*extensions[i]);
applyExtension(*extensions[i].second);
}

bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
// Treat any extensions conservatively.
if (!extensions.empty())
// Check that all extension keys are present in 'rhs'.
llvm::DenseSet<TypeID> rhsExtensionKeys;
{
auto rhsKeys = llvm::map_range(rhs.extensions,
[](const auto &item) { return item.first; });
rhsExtensionKeys.insert(rhsKeys.begin(), rhsKeys.end());
}

if (!llvm::all_of(extensions, [&rhsExtensionKeys](const auto &extension) {
return rhsExtensionKeys.contains(extension.first);
}))
return false;

// Check that the current dialects fully overlap with the dialects in 'rhs'.
return llvm::all_of(
registry, [&](const auto &it) { return rhs.registry.count(it.first); });
Expand Down
56 changes: 50 additions & 6 deletions mlir/unittests/IR/DialectTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/TypeID.h"
#include "gtest/gtest.h"

using namespace mlir;
Expand Down Expand Up @@ -140,15 +141,22 @@ namespace {
/// A dummy extension that increases a counter when being applied and
/// recursively adds additional extensions.
struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyExtension)

DummyExtension(int *counter, int numRecursive)
: DialectExtension(), counter(counter), numRecursive(numRecursive) {}

void apply(MLIRContext *ctx, TestDialect *dialect) const final {
++(*counter);
DialectRegistry nestedRegistry;
for (int i = 0; i < numRecursive; ++i)
nestedRegistry.addExtension(
std::make_unique<DummyExtension>(counter, /*numRecursive=*/0));
for (int i = 0; i < numRecursive; ++i) {
// Create unique TypeIDs for these recursive extensions so they don't get
// de-duplicated.
auto extension =
std::make_unique<DummyExtension>(counter, /*numRecursive=*/0);
auto typeID = TypeID::getFromOpaquePointer(extension.get());
nestedRegistry.addExtension(typeID, std::move(extension));
}
// Adding additional extensions may trigger a reallocation of the
// `extensions` vector in the dialect registry.
ctx->appendDialectRegistry(nestedRegistry);
Expand All @@ -166,20 +174,56 @@ TEST(Dialect, NestedDialectExtension) {

// Add an extension that adds 100 more extensions.
int counter1 = 0;
registry.addExtension(std::make_unique<DummyExtension>(&counter1, 100));
registry.addExtension(TypeID::get<DummyExtension>(),
std::make_unique<DummyExtension>(&counter1, 100));
// Add one more extension. This should not crash.
int counter2 = 0;
registry.addExtension(std::make_unique<DummyExtension>(&counter2, 0));
registry.addExtension(TypeID::get<DummyExtension>(),
std::make_unique<DummyExtension>(&counter2, 0));

// Load dialect and apply extensions.
MLIRContext context(registry);
Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
ASSERT_TRUE(testDialect != nullptr);

// Extensions may be applied multiple times. Make sure that each expected
// Extensions are de-duplicated by typeID. Make sure that each expected
// extension was applied at least once.
EXPECT_GE(counter1, 101);
EXPECT_GE(counter2, 1);
}

TEST(Dialect, SubsetWithExtensions) {
DialectRegistry registry1, registry2;
registry1.insert<TestDialect>();
registry2.insert<TestDialect>();

// Validate that the registries are equivalent.
ASSERT_TRUE(registry1.isSubsetOf(registry2));
ASSERT_TRUE(registry2.isSubsetOf(registry1));

// Add extensions to registry2.
int counter = 0;
registry2.addExtension(TypeID::get<DummyExtension>(),
std::make_unique<DummyExtension>(&counter, 0));

// Expect that (1) is a subset of (2) but not the other way around.
ASSERT_TRUE(registry1.isSubsetOf(registry2));
ASSERT_FALSE(registry2.isSubsetOf(registry1));

// Add extensions to registry1.
registry1.addExtension(TypeID::get<DummyExtension>(),
std::make_unique<DummyExtension>(&counter, 0));

// Expect that (1) and (2) are equivalent.
ASSERT_TRUE(registry1.isSubsetOf(registry2));
ASSERT_TRUE(registry2.isSubsetOf(registry1));

// Load dialect and apply extensions.
MLIRContext context(registry1);
context.getOrLoadDialect<TestDialect>();
context.appendDialectRegistry(registry2);
// Expect that the extension as only invoked once.
ASSERT_EQ(counter, 1);
}

} // namespace
Loading