Skip to content

Commit

Permalink
Revert "Added free-threading CPython mode support in MLIR Python bind…
Browse files Browse the repository at this point in the history
…ings (llvm#107103)"

Breaks on 3.8, rolling back to avoid breakage while fixing.

This reverts commit 9dee7c4.
  • Loading branch information
jpienaar authored and DKLoehr committed Jan 17, 2025
1 parent 90df733 commit 1f30573
Show file tree
Hide file tree
Showing 9 changed files with 16 additions and 649 deletions.
21 changes: 1 addition & 20 deletions mlir/cmake/modules/AddMLIRPython.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -668,31 +668,12 @@ function(add_mlir_python_extension libname extname)
elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind")
nanobind_add_module(${libname}
NB_DOMAIN mlir
FREE_THREADED
${ARG_SOURCES}
)

if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
# Avoids warnings from upstream nanobind.
set(nanobind_target "nanobind-static")
if (NOT TARGET ${nanobind_target})
# Get correct nanobind target name: nanobind-static-ft or something else
# It is set by nanobind_add_module function according to the passed options
get_property(all_targets DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY BUILDSYSTEM_TARGETS)

# Iterate over the list of targets
foreach(target ${all_targets})
# Check if the target name matches the given string
if("${target}" MATCHES "nanobind-")
set(nanobind_target "${target}")
endif()
endforeach()

if (NOT TARGET ${nanobind_target})
message(FATAL_ERROR "Could not find nanobind target to set compile options to")
endif()
endif()
target_compile_options(${nanobind_target}
target_compile_options(nanobind-static
PRIVATE
-Wno-cast-qual
-Wno-zero-length-array
Expand Down
40 changes: 0 additions & 40 deletions mlir/docs/Bindings/Python.md
Original file line number Diff line number Diff line change
Expand Up @@ -1187,43 +1187,3 @@ or nanobind and
utilities to connect to the rest of Python API. The bindings can be located in a
separate module or in the same module as attributes and types, and
loaded along with the dialect.

## Free-threading (No-GIL) support

Free-threading or no-GIL support refers to CPython interpreter (>=3.13) with Global Interpreter Lock made optional. For details on the topic, please check [PEP-703](https://peps.python.org/pep-0703/) and this [Python free-threading guide](https://py-free-threading.github.io/).

MLIR Python bindings are free-threading compatible with exceptions (discussed below) in the following sense: it is safe to work in multiple threads with **independent** contexts. Below we show an example code of safe usage:

```python
# python3.13t example.py
import concurrent.futures

import mlir.dialects.arith as arith
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint


def func(py_value):
with Context() as ctx:
module = Module.create(loc=Location.file("foo.txt", 0, 0))

dtype = IntegerType.get_signless(64)
with InsertionPoint(module.body), Location.name("a"):
arith.constant(dtype, py_value)

return module


num_workers = 8
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
for i in range(num_workers):
futures.append(executor.submit(func, i))
assert len(list(f.result() for f in futures)) == num_workers
```

The exceptions to the free-threading compatibility:
- IR printing is unsafe, e.g. when using `PassManager` with `PassManager.enable_ir_printing()` which calls thread-unsafe `llvm::raw_ostream`.
- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
- Usage of `mlir.dialects.transform.interpreter` is unsafe.
- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.
12 changes: 1 addition & 11 deletions mlir/lib/Bindings/Python/Globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ namespace mlir {
namespace python {

/// Globals that are always accessible once the extension has been initialized.
/// Methods of this class are thread-safe.
class PyGlobals {
public:
PyGlobals();
Expand All @@ -38,18 +37,12 @@ class PyGlobals {

/// Get and set the list of parent modules to search for dialect
/// implementation classes.
std::vector<std::string> getDialectSearchPrefixes() {
nanobind::ft_lock_guard lock(mutex);
std::vector<std::string> &getDialectSearchPrefixes() {
return dialectSearchPrefixes;
}
void setDialectSearchPrefixes(std::vector<std::string> newValues) {
nanobind::ft_lock_guard lock(mutex);
dialectSearchPrefixes.swap(newValues);
}
void addDialectSearchPrefix(std::string value) {
nanobind::ft_lock_guard lock(mutex);
dialectSearchPrefixes.push_back(std::move(value));
}

/// Loads a python module corresponding to the given dialect namespace.
/// No-ops if the module has already been loaded or is not found. Raises
Expand Down Expand Up @@ -116,9 +109,6 @@ class PyGlobals {

private:
static PyGlobals *instance;

nanobind::ft_mutex mutex;

/// Module name prefixes to search under for dialect implementation modules.
std::vector<std::string> dialectSearchPrefixes;
/// Map of dialect namespace to external dialect class object.
Expand Down
31 changes: 4 additions & 27 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,15 +243,9 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes,

/// Wrapper for the global LLVM debugging flag.
struct PyGlobalDebugFlag {
static void set(nb::object &o, bool enable) {
nb::ft_lock_guard lock(mutex);
mlirEnableGlobalDebug(enable);
}
static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); }

static bool get(const nb::object &) {
nb::ft_lock_guard lock(mutex);
return mlirIsGlobalDebugEnabled();
}
static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); }

static void bind(nb::module_ &m) {
// Debug flags.
Expand All @@ -261,7 +255,6 @@ struct PyGlobalDebugFlag {
.def_static(
"set_types",
[](const std::string &type) {
nb::ft_lock_guard lock(mutex);
mlirSetGlobalDebugType(type.c_str());
},
"types"_a, "Sets specific debug types to be produced by LLVM")
Expand All @@ -270,17 +263,11 @@ struct PyGlobalDebugFlag {
pointers.reserve(types.size());
for (const std::string &str : types)
pointers.push_back(str.c_str());
nb::ft_lock_guard lock(mutex);
mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
});
}

private:
static nb::ft_mutex mutex;
};

nb::ft_mutex PyGlobalDebugFlag::mutex;

struct PyAttrBuilderMap {
static bool dunderContains(const std::string &attributeKind) {
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
Expand Down Expand Up @@ -619,7 +606,6 @@ class PyOpOperandIterator {

PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
nb::gil_scoped_acquire acquire;
nb::ft_lock_guard lock(live_contexts_mutex);
auto &liveContexts = getLiveContexts();
liveContexts[context.ptr] = this;
}
Expand All @@ -629,10 +615,7 @@ PyMlirContext::~PyMlirContext() {
// forContext method, which always puts the associated handle into
// liveContexts.
nb::gil_scoped_acquire acquire;
{
nb::ft_lock_guard lock(live_contexts_mutex);
getLiveContexts().erase(context.ptr);
}
getLiveContexts().erase(context.ptr);
mlirContextDestroy(context);
}

Expand All @@ -649,7 +632,6 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) {

PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
nb::gil_scoped_acquire acquire;
nb::ft_lock_guard lock(live_contexts_mutex);
auto &liveContexts = getLiveContexts();
auto it = liveContexts.find(context.ptr);
if (it == liveContexts.end()) {
Expand All @@ -665,17 +647,12 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
return PyMlirContextRef(it->second, std::move(pyRef));
}

nb::ft_mutex PyMlirContext::live_contexts_mutex;

PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
static LiveContextMap liveContexts;
return liveContexts;
}

size_t PyMlirContext::getLiveCount() {
nb::ft_lock_guard lock(live_contexts_mutex);
return getLiveContexts().size();
}
size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }

size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }

Expand Down
18 changes: 2 additions & 16 deletions mlir/lib/Bindings/Python/IRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,8 @@ PyGlobals::PyGlobals() {
PyGlobals::~PyGlobals() { instance = nullptr; }

bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
{
nb::ft_lock_guard lock(mutex);
if (loadedDialectModules.contains(dialectNamespace))
return true;
}
if (loadedDialectModules.contains(dialectNamespace))
return true;
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
nb::object loaded = nb::none();
Expand All @@ -65,14 +62,12 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
return false;
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
// may have occurred, which may do anything.
nb::ft_lock_guard lock(mutex);
loadedDialectModules.insert(dialectNamespace);
return true;
}

void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
nb::callable pyFunc, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = attributeBuilderMap[attributeKind];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
Expand All @@ -86,7 +81,6 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,

void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
nb::callable typeCaster, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = typeCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Type caster is already registered with caster: " +
Expand All @@ -96,7 +90,6 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,

void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
nb::callable valueCaster, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = valueCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Value caster is already registered: " +
Expand All @@ -106,7 +99,6 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,

void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
nb::object pyClass) {
nb::ft_lock_guard lock(mutex);
nb::object &found = dialectClassMap[dialectNamespace];
if (found) {
throw std::runtime_error((llvm::Twine("Dialect namespace '") +
Expand All @@ -118,7 +110,6 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,

void PyGlobals::registerOperationImpl(const std::string &operationName,
nb::object pyClass, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = operationClassMap[operationName];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
Expand All @@ -130,7 +121,6 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,

std::optional<nb::callable>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
nb::ft_lock_guard lock(mutex);
const auto foundIt = attributeBuilderMap.find(attributeKind);
if (foundIt != attributeBuilderMap.end()) {
assert(foundIt->second && "attribute builder is defined");
Expand All @@ -143,7 +133,6 @@ std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
nb::ft_lock_guard lock(mutex);
const auto foundIt = typeCasterMap.find(mlirTypeID);
if (foundIt != typeCasterMap.end()) {
assert(foundIt->second && "type caster is defined");
Expand All @@ -156,7 +145,6 @@ std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
nb::ft_lock_guard lock(mutex);
const auto foundIt = valueCasterMap.find(mlirTypeID);
if (foundIt != valueCasterMap.end()) {
assert(foundIt->second && "value caster is defined");
Expand All @@ -170,7 +158,6 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
// Make sure dialect module is loaded.
if (!loadDialectModule(dialectNamespace))
return std::nullopt;
nb::ft_lock_guard lock(mutex);
const auto foundIt = dialectClassMap.find(dialectNamespace);
if (foundIt != dialectClassMap.end()) {
assert(foundIt->second && "dialect class is defined");
Expand All @@ -188,7 +175,6 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
if (!loadDialectModule(dialectNamespace))
return std::nullopt;

nb::ft_lock_guard lock(mutex);
auto foundIt = operationClassMap.find(operationName);
if (foundIt != operationClassMap.end()) {
assert(foundIt->second && "OpView is defined");
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ class PyMlirContext {
// Note that this holds a handle, which does not imply ownership.
// Mappings will be removed when the context is destructed.
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
static nanobind::ft_mutex live_contexts_mutex;
static LiveContextMap &getLiveContexts();

// Interns all live modules associated with this context. Modules tracked
Expand Down
9 changes: 7 additions & 2 deletions mlir/lib/Bindings/Python/MainModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ NB_MODULE(_mlir, m) {
.def_prop_rw("dialect_search_modules",
&PyGlobals::getDialectSearchPrefixes,
&PyGlobals::setDialectSearchPrefixes)
.def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
"module_name"_a)
.def(
"append_dialect_search_prefix",
[](PyGlobals &self, std::string moduleName) {
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
},
"module_name"_a)
.def(
"_check_dialect_module_loaded",
[](PyGlobals &self, const std::string &dialectNamespace) {
Expand Down Expand Up @@ -72,6 +76,7 @@ NB_MODULE(_mlir, m) {
nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
PyGlobals::get().registerOperationImpl(operationName, opClass,
replace);

// Dict-stuff the new opClass by name onto the dialect class.
nb::object opClassName = opClass.attr("__name__");
dialectClass.attr(opClassName) = opClass;
Expand Down
2 changes: 1 addition & 1 deletion mlir/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ nanobind>=2.4, <3.0
numpy>=1.19.5, <=2.1.2
pybind11>=2.10.0, <=2.13.6
PyYAML>=5.4.0, <=6.0.1
ml_dtypes>=0.5.0, <=0.6.0 # provides several NumPy dtype extensions, including the bf16
ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16
Loading

0 comments on commit 1f30573

Please sign in to comment.