Skip to content

Commit

Permalink
Revisit Dialect registration: require and store a TypeID on dialects
Browse files Browse the repository at this point in the history
This patch moves the registration to a method in the MLIRContext: getOrCreateDialect<ConcreteDialect>()

This method requires dialect to provide a static getDialectNamespace()
and store a TypeID on the Dialect itself, which allows to lazyily
create a dialect when not yet loaded in the context.
As a side effect, it means that duplicated registration of the same
dialect is not an issue anymore.

To limit the boilerplate, TableGen dialect generation is modified to
emit the constructor entirely and invoke separately a "init()" method
that the user implements.

Differential Revision: https://reviews.llvm.org/D85495
  • Loading branch information
joker-eph committed Aug 7, 2020
1 parent d8c6d08 commit 575b22b
Show file tree
Hide file tree
Showing 31 changed files with 115 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ using namespace mlir::standalone;
// Standalone dialect.
//===----------------------------------------------------------------------===//

StandaloneDialect::StandaloneDialect(mlir::MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
void StandaloneDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "Standalone/StandaloneOps.cpp.inc"
Expand Down
3 changes: 2 additions & 1 deletion mlir/examples/toy/Ch2/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ using namespace mlir::toy;

/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
: mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
Expand Down
3 changes: 2 additions & 1 deletion mlir/examples/toy/Ch3/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ using namespace mlir::toy;

/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
: mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
Expand Down
3 changes: 2 additions & 1 deletion mlir/examples/toy/Ch4/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ struct ToyInlinerInterface : public DialectInlinerInterface {

/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
: mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
Expand Down
3 changes: 2 additions & 1 deletion mlir/examples/toy/Ch5/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ struct ToyInlinerInterface : public DialectInlinerInterface {

/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
: mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
Expand Down
3 changes: 2 additions & 1 deletion mlir/examples/toy/Ch6/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ struct ToyInlinerInterface : public DialectInlinerInterface {

/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
: mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
Expand Down
3 changes: 2 additions & 1 deletion mlir/examples/toy/Ch7/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ struct ToyInlinerInterface : public DialectInlinerInterface {

/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
: mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def LLVM_Dialect : Dialect {
private:
friend LLVMType;

std::unique_ptr<detail::LLVMDialectImpl> impl;
// This can't be a unique_ptr because the ctor is generated inline
// in the class definition at the moment.
detail::LLVMDialectImpl *impl;
}];
}

Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class MLIRContext;

class SDBMDialect : public Dialect {
public:
SDBMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) {}
SDBMDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {}

/// Since there are no other virtual methods in this derived class, override
/// the destructor so that key methods get defined in the corresponding
Expand Down
27 changes: 15 additions & 12 deletions mlir/include/mlir/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLIR_IR_DIALECT_H

#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/TypeID.h"

namespace mlir {
class DialectAsmParser;
Expand Down Expand Up @@ -49,6 +50,9 @@ class Dialect {

StringRef getNamespace() const { return name; }

/// Returns the unique identifier that corresponds to this dialect.
TypeID getTypeID() const { return dialectID; }

/// Returns true if this dialect allows for unregistered operations, i.e.
/// operations prefixed with the dialect namespace but not registered with
/// addOperation.
Expand Down Expand Up @@ -177,7 +181,7 @@ class Dialect {
/// with the namespace followed by '.'.
/// Example:
/// - "tf" for the TensorFlow ops like "tf.add".
Dialect(StringRef name, MLIRContext *context);
Dialect(StringRef name, MLIRContext *context, TypeID id);

/// This method is used by derived classes to add their operations to the set.
///
Expand Down Expand Up @@ -223,13 +227,13 @@ class Dialect {
Dialect(const Dialect &) = delete;
void operator=(Dialect &) = delete;

/// Register this dialect object with the specified context. The context
/// takes ownership of the heap allocated dialect.
void registerDialect(MLIRContext *context);

/// The namespace of this dialect.
StringRef name;

/// The unique identifier of the derived Op class, this is used in the context
/// to allow registering multiple times the same dialect.
TypeID dialectID;

/// This is the context that owns this Dialect object.
MLIRContext *context;

Expand All @@ -255,7 +259,9 @@ class Dialect {
const DialectAllocatorFunction &function);
template <typename ConcreteDialect>
friend void registerDialect();
friend class MLIRContext;
};

/// Registers all dialects and hooks from the global registries with the
/// specified MLIRContext.
/// Note: This method is not thread-safe.
Expand All @@ -265,12 +271,9 @@ void registerAllDialects(MLIRContext *context);
/// global registry by calling registerDialect<MyDialect>();
/// Note: This method is not thread-safe.
template <typename ConcreteDialect> void registerDialect() {
Dialect::registerDialectAllocator(TypeID::get<ConcreteDialect>(),
[](MLIRContext *ctx) {
// Just allocate the dialect, the context
// takes ownership of it.
new ConcreteDialect(ctx);
});
Dialect::registerDialectAllocator(
TypeID::get<ConcreteDialect>(),
[](MLIRContext *ctx) { ctx->getOrCreateDialect<ConcreteDialect>(); });
}

/// DialectRegistration provides a global initializer that registers a Dialect
Expand All @@ -291,7 +294,7 @@ namespace llvm {
template <typename T>
struct isa_impl<T, ::mlir::Dialect> {
static inline bool doit(const ::mlir::Dialect &dialect) {
return T::getDialectNamespace() == dialect.getNamespace();
return mlir::TypeID::get<T>() == dialect.getTypeID();
}
};
} // namespace llvm
Expand Down
19 changes: 19 additions & 0 deletions mlir/include/mlir/IR/MLIRContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_IR_MLIRCONTEXT_H

#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include <functional>
#include <memory>
#include <vector>
Expand Down Expand Up @@ -49,6 +50,18 @@ class MLIRContext {
return static_cast<T *>(getRegisteredDialect(T::getDialectNamespace()));
}

/// Get (or create) a dialect for the given derived dialect type. The derived
/// type must provide a static 'getDialectNamespace' method.
template <typename T>
T *getOrCreateDialect() {
return static_cast<T *>(getOrCreateDialect(
T::getDialectNamespace(), TypeID::get<T>(), [this]() {
std::unique_ptr<T> dialect(new T(this));
dialect->dialectID = TypeID::get<T>();
return dialect;
}));
}

/// Return true if we allow to create operation for unregistered dialects.
bool allowsUnregisteredDialects();

Expand Down Expand Up @@ -109,6 +122,12 @@ class MLIRContext {
private:
const std::unique_ptr<MLIRContextImpl> impl;

/// Get a dialect for the provided namespace and TypeID: abort the program if
/// a dialect exist for this namespace with different TypeID. Returns a
/// pointer to the dialect owned by the context.
Dialect *getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
function_ref<std::unique_ptr<Dialect>()> ctor);

MLIRContext(const MLIRContext &) = delete;
void operator=(const MLIRContext &) = delete;
};
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@

using namespace mlir;

avx512::AVX512Dialect::AVX512Dialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
void avx512::AVX512Dialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AVX512/AVX512.cpp.inc"
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ struct AffineInlinerInterface : public DialectInlinerInterface {
// AffineDialect
//===----------------------------------------------------------------------===//

AffineDialect::AffineDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
void AffineDialect::initialize() {
addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ bool GPUDialect::isKernel(Operation *op) {
return static_cast<bool>(isKernelAttr);
}

GPUDialect::GPUDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
void GPUDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

using namespace mlir;

LLVM::LLVMAVX512Dialect::LLVMAVX512Dialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
void LLVM::LLVMAVX512Dialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc"
Expand Down
7 changes: 3 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1683,9 +1683,8 @@ struct LLVMDialectImpl {
} // end namespace LLVM
} // end namespace mlir

LLVMDialect::LLVMDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context),
impl(new detail::LLVMDialectImpl()) {
void LLVMDialect::initialize() {
impl = new detail::LLVMDialectImpl();
// clang-format off
addTypes<LLVMVoidType,
LLVMHalfType,
Expand Down Expand Up @@ -1716,7 +1715,7 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
allowUnknownOperations();
}

LLVMDialect::~LLVMDialect() {}
LLVMDialect::~LLVMDialect() { delete impl; }

#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ static LogicalResult verify(MmaOp op) {
//===----------------------------------------------------------------------===//

// TODO: This should be the llvm.nvvm dialect once this is supported.
NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) {
void NVVMDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ static ParseResult parseROCDLMubufStoreOp(OpAsmParser &parser,
//===----------------------------------------------------------------------===//

// TODO: This should be the llvm.rocdl dialect once this is supported.
ROCDLDialect::ROCDLDialect(MLIRContext *context) : Dialect("rocdl", context) {
void ROCDLDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
using namespace mlir;
using namespace mlir::linalg;

mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
void mlir::linalg::LinalgDialect::initialize() {
addTypes<RangeType>();
addOperations<
#define GET_OP_LIST
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
using namespace mlir;
using namespace mlir::omp;

OpenMPDialect::OpenMPDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
void OpenMPDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Quant/IR/QuantOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ using namespace mlir;
using namespace mlir::quant;
using namespace mlir::quant::detail;

QuantizationDialect::QuantizationDialect(MLIRContext *context)
: Dialect(/*name=*/"quant", context) {
void QuantizationDialect::initialize() {
addTypes<AnyQuantizedType, UniformQuantizedType,
UniformQuantizedPerAxisType>();
addOperations<
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/SCF/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ struct SCFInlinerInterface : public DialectInlinerInterface {
// SCFDialect
//===----------------------------------------------------------------------===//

SCFDialect::SCFDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
void SCFDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/SCFOps.cpp.inc"
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
// SPIR-V Dialect
//===----------------------------------------------------------------------===//

SPIRVDialect::SPIRVDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
void SPIRVDialect::initialize() {
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
PointerType, RuntimeArrayType, StructType>();

Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
return success();
}

ShapeDialect::ShapeDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
void ShapeDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,7 @@ static LogicalResult verifyCastOp(T op) {
return success();
}

StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
void StandardOpsDialect::initialize() {
addOperations<DmaStartOp, DmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Vector/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ using namespace mlir::vector;
// VectorDialect
//===----------------------------------------------------------------------===//

VectorDialect::VectorDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
void VectorDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,9 @@ void mlir::registerAllDialects(MLIRContext *context) {
// Dialect
//===----------------------------------------------------------------------===//

Dialect::Dialect(StringRef name, MLIRContext *context)
: name(name), context(context) {
Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
: name(name), dialectID(id), context(context) {
assert(isValidNamespace(name) && "invalid dialect namespace");
registerDialect(context);
}

Dialect::~Dialect() {}
Expand Down
Loading

0 comments on commit 575b22b

Please sign in to comment.