Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Support DialectRegistry extension comparison #101119

Merged
merged 9 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/examples/transform/Ch2/lib/MyExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
class MyExtension
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
// The TypeID of this extension.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)

// The extension must derive the base constructor.
using Base::Base;

Expand Down
3 changes: 3 additions & 0 deletions mlir/examples/transform/Ch3/lib/MyExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
class MyExtension
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
// The TypeID of this extension.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)

// The extension must derive the base constructor.
using Base::Base;

Expand Down
3 changes: 3 additions & 0 deletions mlir/examples/transform/Ch4/lib/MyExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
class MyExtension
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
// The TypeID of this extension.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)

// The extension must derive the base constructor.
using Base::Base;

Expand Down
37 changes: 19 additions & 18 deletions mlir/include/mlir/IR/DialectRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
#define MLIR_IR_DIALECTREGISTRY_H

#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/MapVector.h"

#include <map>
#include <tuple>
Expand Down Expand Up @@ -187,7 +187,8 @@ class DialectRegistry {
nameAndRegistrationIt.second.second);
// Merge the extensions.
for (const auto &extension : extensions)
destination.extensions.push_back(extension->clone());
destination.extensions.try_emplace(extension.first,
extension.second->clone());
}

/// Return the names of dialects known to this registry.
Expand All @@ -206,47 +207,47 @@ class DialectRegistry {
void applyExtensions(MLIRContext *ctx) const;

/// Add the given extension to the registry.
void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
extensions.push_back(std::move(extension));
bool addExtension(TypeID extensionID,
std::unique_ptr<DialectExtensionBase> extension) {
return extensions.try_emplace(extensionID, std::move(extension)).second;
}

/// Add the given extensions to the registry.
template <typename... ExtensionsT>
void addExtensions() {
(addExtension(std::make_unique<ExtensionsT>()), ...);
(addExtension(TypeID::get<ExtensionsT>(), std::make_unique<ExtensionsT>()),
...);
}

/// Add an extension function that requires the given dialects.
/// Note: This bare functor overload is provided in addition to the
/// std::function variant to enable dialect type deduction, e.g.:
/// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
/// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
/// ... })
///
/// is equivalent to:
/// registry.addExtension<MyDialect>(
/// [](MLIRContext *ctx, MyDialect *dialect){ ... }
/// )
template <typename... DialectsT>
void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
addExtension<DialectsT...>(
std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
}
template <typename... DialectsT>
void
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll have to remove the std::function overload of addExtension so we can use the underlying function pointer as the TypeID, but the lost functionality can be achieved with a DialectExtension subclass if needed.

addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;
bool addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...);

struct Extension : public DialectExtension<Extension, DialectsT...> {
Extension(const Extension &) = default;
Extension(ExtensionFnT extensionFn)
: extensionFn(std::move(extensionFn)) {}
: DialectExtension<Extension, DialectsT...>(),
extensionFn(extensionFn) {}
~Extension() override = default;

void apply(MLIRContext *context, DialectsT *...dialects) const final {
extensionFn(context, dialects...);
}
ExtensionFnT extensionFn;
};
addExtension(std::make_unique<Extension>(std::move(extensionFn)));
return addExtension(TypeID::getFromOpaquePointer(
reinterpret_cast<const void *>(extensionFn)),
std::make_unique<Extension>(extensionFn));
}

/// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
Expand All @@ -255,7 +256,7 @@ class DialectRegistry {

private:
MapTy registry;
std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
};

} // namespace mlir
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ namespace {
/// starting a pass pipeline that involves dialect conversion to LLVM.
class LoadDependentDialectExtension : public DialectExtensionBase {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)

LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}

void apply(MLIRContext *context,
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ class AffineTransformDialectExtension
: public transform::TransformDialectExtension<
AffineTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension)

using Base::Base;

void init() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ class BufferizationTransformDialectExtension
: public transform::TransformDialectExtension<
BufferizationTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
BufferizationTransformDialectExtension)

using Base::Base;

void init() {
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ class FuncTransformDialectExtension
: public transform::TransformDialectExtension<
FuncTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension)

using Base::Base;

void init() {
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,8 @@ class GPUTransformDialectExtension
: public transform::TransformDialectExtension<
GPUTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

GPUTransformDialectExtension() {
declareGeneratedDialect<scf::SCFDialect>();
declareGeneratedDialect<arith::ArithDialect>();
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class LinalgTransformDialectExtension
: public transform::TransformDialectExtension<
LinalgTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LinalgTransformDialectExtension)

using Base::Base;

void init() {
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,8 @@ class MemRefTransformDialectExtension
: public transform::TransformDialectExtension<
MemRefTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemRefTransformDialectExtension)

using Base::Base;

void init() {
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,8 @@ class NVGPUTransformDialectExtension
: public transform::TransformDialectExtension<
NVGPUTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)

NVGPUTransformDialectExtension() {
declareGeneratedDialect<arith::ArithDialect>();
declareGeneratedDialect<affine::AffineDialect>();
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,8 @@ class SCFTransformDialectExtension
: public transform::TransformDialectExtension<
SCFTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension)

using Base::Base;

void init() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class SparseTensorTransformDialectExtension
: public transform::TransformDialectExtension<
SparseTensorTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
SparseTensorTransformDialectExtension)

SparseTensorTransformDialectExtension() {
declareGeneratedDialect<sparse_tensor::SparseTensorDialect>();
registerTransformOps<
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ class TensorTransformDialectExtension
: public transform::TransformDialectExtension<
TensorTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TensorTransformDialectExtension)

using Base::Base;

void init() {
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ namespace {
class DebugExtension
: public transform::TransformDialectExtension<DebugExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DebugExtension)

void init() {
registerTransformOps<
#define GET_OP_LIST
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ namespace {
class IRDLExtension
: public transform::TransformDialectExtension<IRDLExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IRDLExtension)

void init() {
registerTransformOps<
#define GET_OP_LIST
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ namespace {
class LoopExtension
: public transform::TransformDialectExtension<LoopExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoopExtension)

void init() {
registerTransformOps<
#define GET_OP_LIST
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ namespace {
/// with Transform dialect operations.
class PDLExtension : public transform::TransformDialectExtension<PDLExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PDLExtension)

void init() {
registerTransformOps<
#define GET_OP_LIST
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ class VectorTransformDialectExtension
: public transform::TransformDialectExtension<
VectorTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorTransformDialectExtension)

VectorTransformDialectExtension() {
declareGeneratedDialect<vector::VectorDialect>();
declareGeneratedDialect<LLVM::LLVMDialect>();
Expand Down
56 changes: 48 additions & 8 deletions mlir/lib/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,20 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Regex.h"
#include <memory>

#define DEBUG_TYPE "dialect"

Expand Down Expand Up @@ -173,6 +179,40 @@ bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect,
// DialectRegistry
//===----------------------------------------------------------------------===//

namespace {
template <typename Fn>
void applyExtensionsFn(
Fn &&applyExtension,
const llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>>
&extensions) {
// Note: Additional extensions may be added while applying an extension.
// The iterators will be invalidated if extensions are added so we'll keep
// a copy of the extensions for ourselves.

const auto extractExtension =
[](const auto &entry) -> DialectExtensionBase * {
return entry.second.get();
};

auto startIt = extensions.begin(), endIt = extensions.end();
size_t count = 0;
while (startIt != endIt) {
count += endIt - startIt;

// Grab the subset of extensions we'll apply in this iteration.
const auto subset =
llvm::map_to_vector(llvm::make_range(startIt, endIt), extractExtension);

for (const auto *ext : subset)
applyExtension(*ext);

// Book-keep for the next iteration.
startIt = extensions.begin() + count;
endIt = extensions.end();
}
}
} // namespace

DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }

DialectAllocatorFunctionRef
Expand Down Expand Up @@ -258,9 +298,7 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const {
extension.apply(ctx, requiredDialects);
};

// Note: Additional extensions may be added while applying an extension.
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
applyExtension(*extensions[i]);
applyExtensionsFn(applyExtension, extensions);
}

void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
Expand All @@ -285,15 +323,17 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
extension.apply(ctx, requiredDialects);
};

// Note: Additional extensions may be added while applying an extension.
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
applyExtension(*extensions[i]);
applyExtensionsFn(applyExtension, extensions);
}

bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
// Treat any extensions conservatively.
if (!extensions.empty())
// Check that all extension keys are present in 'rhs'.
const auto hasExtension = [&](const auto &key) {
return rhs.extensions.contains(key);
};
if (!llvm::all_of(make_first_range(extensions), hasExtension))
return false;

// Check that the current dialects fully overlap with the dialects in 'rhs'.
return llvm::all_of(
registry, [&](const auto &it) { return rhs.registry.count(it.first); });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,8 @@ class TestTransformDialectExtension
: public transform::TransformDialectExtension<
TestTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectExtension)

using Base::Base;

void init() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,9 @@ class TestTilingInterfaceDialectExtension
: public transform::TransformDialectExtension<
TestTilingInterfaceDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestTilingInterfaceDialectExtension)

using Base::Base;

void init() {
Expand Down
2 changes: 2 additions & 0 deletions mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ using namespace mlir::transform;
namespace {
class Extension : public TransformDialectExtension<Extension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Extension)

using Base::Base;
void init() { declareGeneratedDialect<func::FuncDialect>(); }
};
Expand Down
Loading
Loading