Skip to content

Commit

Permalink
[mlir python] Port in-tree dialects to nanobind. (#119924)
Browse files Browse the repository at this point in the history
This is a companion to #118583, although it can be landed independently
because since #117922 dialects do not have to use the same Python
binding framework as the Python core code.

This PR ports all of the in-tree dialect and pass extensions to
nanobind, with the exception of those that remain for testing pybind11
support.

This PR also:
* removes CollectDiagnosticsToStringScope from NanobindAdaptors.h. This
was overlooked in a previous PR and it is duplicated in Diagnostics.h.

---------

Co-authored-by: Jacques Pienaar <jpienaar@google.com>
  • Loading branch information
hawkinsp and jpienaar authored Dec 21, 2024
1 parent 559f080 commit 5cd4274
Show file tree
Hide file tree
Showing 35 changed files with 357 additions and 360 deletions.
12 changes: 12 additions & 0 deletions mlir/cmake/modules/AddMLIRPython.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,18 @@ function(add_mlir_python_extension libname extname)
NB_DOMAIN mlir
${ARG_SOURCES}
)

if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
# Avoids warnings from upstream nanobind.
target_compile_options(nanobind-static
PRIVATE
-Wno-cast-qual
-Wno-zero-length-array
-Wno-nested-anon-types
-Wno-c++98-compat-extra-semi
-Wno-covered-switch-default
)
endif()
endif()

# The extension itself must be compiled with RTTI and exceptions enabled.
Expand Down
12 changes: 12 additions & 0 deletions mlir/cmake/modules/MLIRDetectPythonEnv.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,17 @@ function(mlir_detect_nanobind_install)
endif()
message(STATUS "found (${PACKAGE_DIR})")
set(nanobind_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
execute_process(
COMMAND "${Python3_EXECUTABLE}"
-c "import nanobind;print(nanobind.include_dir(), end='')"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
RESULT_VARIABLE STATUS
OUTPUT_VARIABLE PACKAGE_DIR
ERROR_QUIET)
if(NOT STATUS EQUAL "0")
message(STATUS "not found (install via 'pip install nanobind' or set nanobind_DIR)")
return()
endif()
set(nanobind_INCLUDE_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
endif()
endfunction()
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
//
//===----------------------------------------------------------------------===//

#include <nanobind/nanobind.h>

#include "Standalone-c/Dialects.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"

namespace nb = nanobind;
Expand Down
37 changes: 37 additions & 0 deletions mlir/include/mlir/Bindings/Python/Nanobind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- Nanobind.h - Trampoline header with ignored warnings ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file is a trampoline for the nanobind headers while disabling warnings
// reported by the LLVM/MLIR build. This file avoids adding complexity build
// system side.
//===----------------------------------------------------------------------===//

#ifndef MLIR_BINDINGS_PYTHON_NANOBIND_H
#define MLIR_BINDINGS_PYTHON_NANOBIND_H

#if defined(__clang__) || defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wzero-length-array"
#pragma GCC diagnostic ignored "-Wcast-qual"
#pragma GCC diagnostic ignored "-Wnested-anon-types"
#pragma GCC diagnostic ignored "-Wc++98-compat-extra-semi"
#pragma GCC diagnostic ignored "-Wcovered-switch-default"
#endif
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/function.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/string_view.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/vector.h>
#if defined(__clang__) || defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

#endif // MLIR_BINDINGS_PYTHON_NANOBIND_H
40 changes: 2 additions & 38 deletions mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@
#ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
#define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H

#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>

#include <cstdint>

#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "llvm/ADT/Twine.h"

// Raw CAPI type casters need to be declared before use, so always include them
Expand Down Expand Up @@ -631,40 +629,6 @@ class mlir_value_subclass : public pure_subclass {

} // namespace nanobind_adaptors

/// RAII scope intercepting all diagnostics into a string. The message must be
/// checked before this goes out of scope.
class CollectDiagnosticsToStringScope {
public:
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
/*deleteUserData=*/nullptr);
}
~CollectDiagnosticsToStringScope() {
assert(errorMessage.empty() && "unchecked error message");
mlirContextDetachDiagnosticHandler(context, handlerID);
}

[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }

private:
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
auto printer = +[](MlirStringRef message, void *data) {
*static_cast<std::string *>(data) +=
llvm::StringRef(message.data, message.length);
};
MlirLocation loc = mlirDiagnosticGetLocation(diag);
*static_cast<std::string *>(data) += "at ";
mlirLocationPrint(loc, printer, data);
*static_cast<std::string *>(data) += ": ";
mlirDiagnosticPrint(diag, printer, data);
return mlirLogicalResultSuccess();
}

MlirContext context;
MlirDiagnosticHandlerID handlerID;
std::string errorMessage = "";
};

} // namespace python
} // namespace mlir

Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Bindings/Python/AsyncPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@

#include "mlir-c/Dialect/Async.h"

#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include "mlir/Bindings/Python/Nanobind.h"

// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------

PYBIND11_MODULE(_mlirAsyncPasses, m) {
NB_MODULE(_mlirAsyncPasses, m) {
m.doc() = "MLIR Async Dialect Passes";

// Register all Async passes on load.
Expand Down
44 changes: 22 additions & 22 deletions mlir/lib/Bindings/Python/DialectGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@
#include "mlir-c/Dialect/GPU.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"

#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
namespace nb = nanobind;
using namespace nanobind::literals;

namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::adaptors;
using namespace mlir::python::nanobind_adaptors;

// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------

PYBIND11_MODULE(_mlirDialectsGPU, m) {
NB_MODULE(_mlirDialectsGPU, m) {
m.doc() = "MLIR GPU Dialect";
//===-------------------------------------------------------------------===//
// AsyncTokenType
Expand All @@ -34,11 +34,11 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {

mlirGPUAsyncTokenType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
[](nb::object cls, MlirContext ctx) {
return cls(mlirGPUAsyncTokenTypeGet(ctx));
},
"Gets an instance of AsyncTokenType in the same context", py::arg("cls"),
py::arg("ctx") = py::none());
"Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
nb::arg("ctx").none() = nb::none());

//===-------------------------------------------------------------------===//
// ObjectAttr
Expand All @@ -47,12 +47,12 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
.def_classmethod(
"get",
[](py::object cls, MlirAttribute target, uint32_t format,
py::bytes object, std::optional<MlirAttribute> mlirObjectProps,
[](nb::object cls, MlirAttribute target, uint32_t format,
nb::bytes object, std::optional<MlirAttribute> mlirObjectProps,
std::optional<MlirAttribute> mlirKernelsAttr) {
py::buffer_info info(py::buffer(object).request());
MlirStringRef objectStrRef =
mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
MlirStringRef objectStrRef = mlirStringRefCreate(
static_cast<char *>(const_cast<void *>(object.data())),
object.size());
return cls(mlirGPUObjectAttrGetWithKernels(
mlirAttributeGetContext(target), target, format, objectStrRef,
mlirObjectProps.has_value() ? *mlirObjectProps
Expand All @@ -61,7 +61,7 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
: MlirAttribute{nullptr}));
},
"cls"_a, "target"_a, "format"_a, "object"_a,
"properties"_a = py::none(), "kernels"_a = py::none(),
"properties"_a.none() = nb::none(), "kernels"_a.none() = nb::none(),
"Gets a gpu.object from parameters.")
.def_property_readonly(
"target",
Expand All @@ -73,18 +73,18 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
"object",
[](MlirAttribute self) {
MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
return py::bytes(stringRef.data, stringRef.length);
return nb::bytes(stringRef.data, stringRef.length);
})
.def_property_readonly("properties",
[](MlirAttribute self) {
[](MlirAttribute self) -> nb::object {
if (mlirGPUObjectAttrHasProperties(self))
return py::cast(
return nb::cast(
mlirGPUObjectAttrGetProperties(self));
return py::none().cast<py::object>();
return nb::none();
})
.def_property_readonly("kernels", [](MlirAttribute self) {
.def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
if (mlirGPUObjectAttrHasKernels(self))
return py::cast(mlirGPUObjectAttrGetKernels(self));
return py::none().cast<py::object>();
return nb::cast(mlirGPUObjectAttrGetKernels(self));
return nb::none();
});
}
54 changes: 29 additions & 25 deletions mlir/lib/Bindings/Python/DialectLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"

namespace nb = nanobind;

using namespace nanobind::literals;

namespace py = pybind11;
using namespace llvm;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::adaptors;
using namespace mlir::python::nanobind_adaptors;

void populateDialectLLVMSubmodule(const pybind11::module &m) {
void populateDialectLLVMSubmodule(const nanobind::module_ &m) {

//===--------------------------------------------------------------------===//
// StructType
Expand All @@ -31,58 +35,58 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {

llvmStructType.def_classmethod(
"get_literal",
[](py::object cls, const std::vector<MlirType> &elements, bool packed,
[](nb::object cls, const std::vector<MlirType> &elements, bool packed,
MlirLocation loc) {
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));

MlirType type = mlirLLVMStructTypeLiteralGetChecked(
loc, elements.size(), elements.data(), packed);
if (mlirTypeIsNull(type)) {
throw py::value_error(scope.takeMessage());
throw nb::value_error(scope.takeMessage().c_str());
}
return cls(type);
},
"cls"_a, "elements"_a, py::kw_only(), "packed"_a = false,
"loc"_a = py::none());
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
"loc"_a.none() = nb::none());

llvmStructType.def_classmethod(
"get_identified",
[](py::object cls, const std::string &name, MlirContext context) {
[](nb::object cls, const std::string &name, MlirContext context) {
return cls(mlirLLVMStructTypeIdentifiedGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
"cls"_a, "name"_a, py::kw_only(), "context"_a = py::none());
"cls"_a, "name"_a, nb::kw_only(), "context"_a.none() = nb::none());

llvmStructType.def_classmethod(
"get_opaque",
[](py::object cls, const std::string &name, MlirContext context) {
[](nb::object cls, const std::string &name, MlirContext context) {
return cls(mlirLLVMStructTypeOpaqueGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
"cls"_a, "name"_a, "context"_a = py::none());
"cls"_a, "name"_a, "context"_a.none() = nb::none());

llvmStructType.def(
"set_body",
[](MlirType self, const std::vector<MlirType> &elements, bool packed) {
MlirLogicalResult result = mlirLLVMStructTypeSetBody(
self, elements.size(), elements.data(), packed);
if (!mlirLogicalResultIsSuccess(result)) {
throw py::value_error(
throw nb::value_error(
"Struct body already set to different content.");
}
},
"elements"_a, py::kw_only(), "packed"_a = false);
"elements"_a, nb::kw_only(), "packed"_a = false);

llvmStructType.def_classmethod(
"new_identified",
[](py::object cls, const std::string &name,
[](nb::object cls, const std::string &name,
const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
return cls(mlirLLVMStructTypeIdentifiedNewGet(
ctx, mlirStringRefCreate(name.data(), name.length()),
elements.size(), elements.data(), packed));
},
"cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false,
"context"_a = py::none());
"cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
"context"_a.none() = nb::none());

llvmStructType.def_property_readonly(
"name", [](MlirType type) -> std::optional<std::string> {
Expand All @@ -93,12 +97,12 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
return StringRef(stringRef.data, stringRef.length).str();
});

llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
llvmStructType.def_property_readonly("body", [](MlirType type) -> nb::object {
// Don't crash in absence of a body.
if (mlirLLVMStructTypeIsOpaque(type))
return py::none();
return nb::none();

py::list body;
nb::list body;
for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
++i) {
body.append(mlirLLVMStructTypeGetElementType(type, i));
Expand All @@ -119,24 +123,24 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
.def_classmethod(
"get",
[](py::object cls, std::optional<unsigned> addressSpace,
[](nb::object cls, std::optional<unsigned> addressSpace,
MlirContext context) {
CollectDiagnosticsToStringScope scope(context);
MlirType type = mlirLLVMPointerTypeGet(
context, addressSpace.has_value() ? *addressSpace : 0);
if (mlirTypeIsNull(type)) {
throw py::value_error(scope.takeMessage());
throw nb::value_error(scope.takeMessage().c_str());
}
return cls(type);
},
"cls"_a, "address_space"_a = py::none(), py::kw_only(),
"context"_a = py::none())
"cls"_a, "address_space"_a.none() = nb::none(), nb::kw_only(),
"context"_a.none() = nb::none())
.def_property_readonly("address_space", [](MlirType type) {
return mlirLLVMPointerTypeGetAddressSpace(type);
});
}

PYBIND11_MODULE(_mlirDialectsLLVM, m) {
NB_MODULE(_mlirDialectsLLVM, m) {
m.doc() = "MLIR LLVM Dialect";

populateDialectLLVMSubmodule(m);
Expand Down
Loading

0 comments on commit 5cd4274

Please sign in to comment.