From 1f305737c7a24861933935f6ea79ed15b6e32f52 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sun, 12 Jan 2025 18:30:42 +0000 Subject: [PATCH] Revert "Added free-threading CPython mode support in MLIR Python bindings (#107103)" Breaks on 3.8, rolling back to avoid breakage while fixing. This reverts commit 9dee7c44491635ec9037b90050bcdbd3d5291e38. --- mlir/cmake/modules/AddMLIRPython.cmake | 21 +- mlir/docs/Bindings/Python.md | 40 -- mlir/lib/Bindings/Python/Globals.h | 12 +- mlir/lib/Bindings/Python/IRCore.cpp | 31 +- mlir/lib/Bindings/Python/IRModule.cpp | 18 +- mlir/lib/Bindings/Python/IRModule.h | 1 - mlir/lib/Bindings/Python/MainModule.cpp | 9 +- mlir/python/requirements.txt | 2 +- mlir/test/python/multithreaded_tests.py | 531 ------------------------ 9 files changed, 16 insertions(+), 649 deletions(-) delete mode 100644 mlir/test/python/multithreaded_tests.py diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake index 0679db9cf93e19..717a503468a85d 100644 --- a/mlir/cmake/modules/AddMLIRPython.cmake +++ b/mlir/cmake/modules/AddMLIRPython.cmake @@ -668,31 +668,12 @@ function(add_mlir_python_extension libname extname) elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind") nanobind_add_module(${libname} NB_DOMAIN mlir - FREE_THREADED ${ARG_SOURCES} ) if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL) # Avoids warnings from upstream nanobind. - set(nanobind_target "nanobind-static") - if (NOT TARGET ${nanobind_target}) - # Get correct nanobind target name: nanobind-static-ft or something else - # It is set by nanobind_add_module function according to the passed options - get_property(all_targets DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY BUILDSYSTEM_TARGETS) - - # Iterate over the list of targets - foreach(target ${all_targets}) - # Check if the target name matches the given string - if("${target}" MATCHES "nanobind-") - set(nanobind_target "${target}") - endif() - endforeach() - - if (NOT TARGET ${nanobind_target}) - message(FATAL_ERROR "Could not find nanobind target to set compile options to") - endif() - endif() - target_compile_options(${nanobind_target} + target_compile_options(nanobind-static PRIVATE -Wno-cast-qual -Wno-zero-length-array diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md index b8bd0f507a5108..32df3310d811d7 100644 --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -1187,43 +1187,3 @@ or nanobind and utilities to connect to the rest of Python API. The bindings can be located in a separate module or in the same module as attributes and types, and loaded along with the dialect. - -## Free-threading (No-GIL) support - -Free-threading or no-GIL support refers to CPython interpreter (>=3.13) with Global Interpreter Lock made optional. For details on the topic, please check [PEP-703](https://peps.python.org/pep-0703/) and this [Python free-threading guide](https://py-free-threading.github.io/). - -MLIR Python bindings are free-threading compatible with exceptions (discussed below) in the following sense: it is safe to work in multiple threads with **independent** contexts. Below we show an example code of safe usage: - -```python -# python3.13t example.py -import concurrent.futures - -import mlir.dialects.arith as arith -from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint - - -def func(py_value): - with Context() as ctx: - module = Module.create(loc=Location.file("foo.txt", 0, 0)) - - dtype = IntegerType.get_signless(64) - with InsertionPoint(module.body), Location.name("a"): - arith.constant(dtype, py_value) - - return module - - -num_workers = 8 -with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [] - for i in range(num_workers): - futures.append(executor.submit(func, i)) - assert len(list(f.result() for f in futures)) == num_workers -``` - -The exceptions to the free-threading compatibility: -- IR printing is unsafe, e.g. when using `PassManager` with `PassManager.enable_ir_printing()` which calls thread-unsafe `llvm::raw_ostream`. -- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`). -- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`). -- Usage of `mlir.dialects.transform.interpreter` is unsafe. -- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe. \ No newline at end of file diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 826a34a5351765..0ec522d14f74bd 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -24,7 +24,6 @@ namespace mlir { namespace python { /// Globals that are always accessible once the extension has been initialized. -/// Methods of this class are thread-safe. class PyGlobals { public: PyGlobals(); @@ -38,18 +37,12 @@ class PyGlobals { /// Get and set the list of parent modules to search for dialect /// implementation classes. - std::vector getDialectSearchPrefixes() { - nanobind::ft_lock_guard lock(mutex); + std::vector &getDialectSearchPrefixes() { return dialectSearchPrefixes; } void setDialectSearchPrefixes(std::vector newValues) { - nanobind::ft_lock_guard lock(mutex); dialectSearchPrefixes.swap(newValues); } - void addDialectSearchPrefix(std::string value) { - nanobind::ft_lock_guard lock(mutex); - dialectSearchPrefixes.push_back(std::move(value)); - } /// Loads a python module corresponding to the given dialect namespace. /// No-ops if the module has already been loaded or is not found. Raises @@ -116,9 +109,6 @@ class PyGlobals { private: static PyGlobals *instance; - - nanobind::ft_mutex mutex; - /// Module name prefixes to search under for dialect implementation modules. std::vector dialectSearchPrefixes; /// Map of dialect namespace to external dialect class object. diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 463ebdebb3f3f6..453d4f7c7e8bca 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -243,15 +243,9 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes, /// Wrapper for the global LLVM debugging flag. struct PyGlobalDebugFlag { - static void set(nb::object &o, bool enable) { - nb::ft_lock_guard lock(mutex); - mlirEnableGlobalDebug(enable); - } + static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); } - static bool get(const nb::object &) { - nb::ft_lock_guard lock(mutex); - return mlirIsGlobalDebugEnabled(); - } + static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); } static void bind(nb::module_ &m) { // Debug flags. @@ -261,7 +255,6 @@ struct PyGlobalDebugFlag { .def_static( "set_types", [](const std::string &type) { - nb::ft_lock_guard lock(mutex); mlirSetGlobalDebugType(type.c_str()); }, "types"_a, "Sets specific debug types to be produced by LLVM") @@ -270,17 +263,11 @@ struct PyGlobalDebugFlag { pointers.reserve(types.size()); for (const std::string &str : types) pointers.push_back(str.c_str()); - nb::ft_lock_guard lock(mutex); mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); }); } - -private: - static nb::ft_mutex mutex; }; -nb::ft_mutex PyGlobalDebugFlag::mutex; - struct PyAttrBuilderMap { static bool dunderContains(const std::string &attributeKind) { return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); @@ -619,7 +606,6 @@ class PyOpOperandIterator { PyMlirContext::PyMlirContext(MlirContext context) : context(context) { nb::gil_scoped_acquire acquire; - nb::ft_lock_guard lock(live_contexts_mutex); auto &liveContexts = getLiveContexts(); liveContexts[context.ptr] = this; } @@ -629,10 +615,7 @@ PyMlirContext::~PyMlirContext() { // forContext method, which always puts the associated handle into // liveContexts. nb::gil_scoped_acquire acquire; - { - nb::ft_lock_guard lock(live_contexts_mutex); - getLiveContexts().erase(context.ptr); - } + getLiveContexts().erase(context.ptr); mlirContextDestroy(context); } @@ -649,7 +632,6 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) { PyMlirContextRef PyMlirContext::forContext(MlirContext context) { nb::gil_scoped_acquire acquire; - nb::ft_lock_guard lock(live_contexts_mutex); auto &liveContexts = getLiveContexts(); auto it = liveContexts.find(context.ptr); if (it == liveContexts.end()) { @@ -665,17 +647,12 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) { return PyMlirContextRef(it->second, std::move(pyRef)); } -nb::ft_mutex PyMlirContext::live_contexts_mutex; - PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { static LiveContextMap liveContexts; return liveContexts; } -size_t PyMlirContext::getLiveCount() { - nb::ft_lock_guard lock(live_contexts_mutex); - return getLiveContexts().size(); -} +size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index e600f1bbd44932..f7bf77e5a7e043 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -38,11 +38,8 @@ PyGlobals::PyGlobals() { PyGlobals::~PyGlobals() { instance = nullptr; } bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { - { - nb::ft_lock_guard lock(mutex); - if (loadedDialectModules.contains(dialectNamespace)) - return true; - } + if (loadedDialectModules.contains(dialectNamespace)) + return true; // Since re-entrancy is possible, make a copy of the search prefixes. std::vector localSearchPrefixes = dialectSearchPrefixes; nb::object loaded = nb::none(); @@ -65,14 +62,12 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { return false; // Note: Iterator cannot be shared from prior to loading, since re-entrancy // may have occurred, which may do anything. - nb::ft_lock_guard lock(mutex); loadedDialectModules.insert(dialectNamespace); return true; } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, nb::callable pyFunc, bool replace) { - nb::ft_lock_guard lock(mutex); nb::object &found = attributeBuilderMap[attributeKind]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + @@ -86,7 +81,6 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, nb::callable typeCaster, bool replace) { - nb::ft_lock_guard lock(mutex); nb::object &found = typeCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Type caster is already registered with caster: " + @@ -96,7 +90,6 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, nb::callable valueCaster, bool replace) { - nb::ft_lock_guard lock(mutex); nb::object &found = valueCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Value caster is already registered: " + @@ -106,7 +99,6 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, nb::object pyClass) { - nb::ft_lock_guard lock(mutex); nb::object &found = dialectClassMap[dialectNamespace]; if (found) { throw std::runtime_error((llvm::Twine("Dialect namespace '") + @@ -118,7 +110,6 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, void PyGlobals::registerOperationImpl(const std::string &operationName, nb::object pyClass, bool replace) { - nb::ft_lock_guard lock(mutex); nb::object &found = operationClassMap[operationName]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Operation '") + operationName + @@ -130,7 +121,6 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, std::optional PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { - nb::ft_lock_guard lock(mutex); const auto foundIt = attributeBuilderMap.find(attributeKind); if (foundIt != attributeBuilderMap.end()) { assert(foundIt->second && "attribute builder is defined"); @@ -143,7 +133,6 @@ std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); - nb::ft_lock_guard lock(mutex); const auto foundIt = typeCasterMap.find(mlirTypeID); if (foundIt != typeCasterMap.end()) { assert(foundIt->second && "type caster is defined"); @@ -156,7 +145,6 @@ std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); - nb::ft_lock_guard lock(mutex); const auto foundIt = valueCasterMap.find(mlirTypeID); if (foundIt != valueCasterMap.end()) { assert(foundIt->second && "value caster is defined"); @@ -170,7 +158,6 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { // Make sure dialect module is loaded. if (!loadDialectModule(dialectNamespace)) return std::nullopt; - nb::ft_lock_guard lock(mutex); const auto foundIt = dialectClassMap.find(dialectNamespace); if (foundIt != dialectClassMap.end()) { assert(foundIt->second && "dialect class is defined"); @@ -188,7 +175,6 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) { if (!loadDialectModule(dialectNamespace)) return std::nullopt; - nb::ft_lock_guard lock(mutex); auto foundIt = operationClassMap.find(operationName); if (foundIt != operationClassMap.end()) { assert(foundIt->second && "OpView is defined"); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index f5fbb6c61b57e2..8fb32a225e65f1 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -260,7 +260,6 @@ class PyMlirContext { // Note that this holds a handle, which does not imply ownership. // Mappings will be removed when the context is destructed. using LiveContextMap = llvm::DenseMap; - static nanobind::ft_mutex live_contexts_mutex; static LiveContextMap &getLiveContexts(); // Interns all live modules associated with this context. Modules tracked diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 6f49431006605a..7c4064262012ef 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -30,8 +30,12 @@ NB_MODULE(_mlir, m) { .def_prop_rw("dialect_search_modules", &PyGlobals::getDialectSearchPrefixes, &PyGlobals::setDialectSearchPrefixes) - .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix, - "module_name"_a) + .def( + "append_dialect_search_prefix", + [](PyGlobals &self, std::string moduleName) { + self.getDialectSearchPrefixes().push_back(std::move(moduleName)); + }, + "module_name"_a) .def( "_check_dialect_module_loaded", [](PyGlobals &self, const std::string &dialectNamespace) { @@ -72,6 +76,7 @@ NB_MODULE(_mlir, m) { nanobind::cast(opClass.attr("OPERATION_NAME")); PyGlobals::get().registerOperationImpl(operationName, opClass, replace); + // Dict-stuff the new opClass by name onto the dialect class. nb::object opClassName = opClass.attr("__name__"); dialectClass.attr(opClassName) = opClass; diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 259e679f510f70..f240d6ef944ec7 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -2,4 +2,4 @@ nanobind>=2.4, <3.0 numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 -ml_dtypes>=0.5.0, <=0.6.0 # provides several NumPy dtype extensions, including the bf16 +ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16 diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py deleted file mode 100644 index 2df75e2e1b90ca..00000000000000 --- a/mlir/test/python/multithreaded_tests.py +++ /dev/null @@ -1,531 +0,0 @@ -# RUN: %PYTHON %s -""" -This script generates multi-threaded tests to check free-threading mode using CPython compiled with TSAN. -Tests can be run using pytest: -```bash -python3.13t -mpytest -vvv multithreaded_tests.py -``` - -IMPORTANT. Running tests are not checking the correctness, but just the execution of the tests in multi-threaded context -and passing if no warnings reported by TSAN and failing otherwise. - - -Details on the generated tests and execution: -1) Multi-threaded execution: all generated tests are executed independently by -a pool of threads, running each test multiple times, see @multi_threaded for details - -2) Tests generation: we use existing tests: test/python/ir/*.py, -test/python/dialects/*.py, etc to generate multi-threaded tests. -In details, we perform the following: -a) we define a list of source tests to be used to generate multi-threaded tests, see `TEST_MODULES`. -b) we define `TestAllMultiThreaded` class and add existing tests to the class. See `add_existing_tests` method. -c) for each test file, we copy and modify it: test/python/ir/affine_expr.py -> /tmp/ir/affine_expr.py. -In order to import the test file as python module, we remove all executing functions, like -`@run` or `run(testMethod)`. See `copy_and_update` and `add_existing_tests` methods for details. - - -Observed warnings reported by TSAN. - -CPython and free-threading known data-races: -1) ctypes related races: https://github.com/python/cpython/issues/127945 -2) LLVM related data-races, llvm::raw_ostream is not thread-safe -- mlir pass manager -- dialects/transform_interpreter.py -- ir/diagnostic_handler.py -- ir/module.py -3) Dialect gpu module-to-binary method is unsafe -""" -import concurrent.futures -import gc -import importlib.util -import os -import sys -import threading -import tempfile -import unittest - -from contextlib import contextmanager -from functools import partial -from pathlib import Path -from typing import Optional - -import mlir.dialects.arith as arith -from mlir.dialects import transform -from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint - - -def import_from_path(module_name: str, file_path: Path): - spec = importlib.util.spec_from_file_location(module_name, file_path) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - return module - - -def copy_and_update(src_filepath: Path, dst_filepath: Path): - # We should remove all calls like `run(testMethod)` - with open(src_filepath, "r") as reader, open(dst_filepath, "w") as writer: - while True: - src_line = reader.readline() - if len(src_line) == 0: - break - skip_lines = [ - "run(", - "@run", - "@constructAndPrintInModule", - "run_apply_patterns(", - "@run_apply_patterns", - "@test_in_context", - "@construct_and_print_in_module", - ] - if any(src_line.startswith(line) for line in skip_lines): - continue - writer.write(src_line) - - -# Helper run functions -# They are copied from the test modules (e.g. run function in execution_engine.py) -def run(test_function): - # Generic run tests function used by dialects and ir test modules - test_function() - - -def run_with_context_and_location(test_function): - # run tests function with a context and a location - # used by the following test modules: - # - dialects/transform_gpu_ext, - # - dialects/vector - # - dialects/gpu/* - with Context(), Location.unknown(): - test_function() - return test_function - - -def run_with_insertion_point_and_context_arg(test_function): - # run tests function used by dialects/index_dialect test module - with Context() as ctx, Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - test_function(ctx) - - -def run_with_insertion_point(test_function): - # Used by a lot of dialects test modules - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - test_function() - return test_function - - -def run_with_insertion_point_and_module_arg(test_function): - # Used by dialects/transform test module - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - test_function(module) - return test_function - - -def run_with_insertion_point_all_unreg_dialects(test_function): - # Used by dialects/cf test module - with Context() as ctx, Location.unknown(): - ctx.allow_unregistered_dialects = True - module = Module.create() - with InsertionPoint(module.body): - test_function() - return test_function - - -def run_apply_patterns(test_function): - # Used by dialects/transform_tensor_ext test module - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, - [], - transform.AnyOpType.get(), - ) - with InsertionPoint(sequence.body): - apply = transform.ApplyPatternsOp(sequence.bodyTarget) - with InsertionPoint(apply.patterns): - test_function() - transform.YieldOp() - print(module) - return test_function - - -def run_transform_tensor_ext(test_function): - # Used by test modules: - # - dialects/transform_gpu_ext - # - dialects/transform_sparse_tensor_ext - # - dialects/transform_tensor_ext - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, - [], - transform.AnyOpType.get(), - ) - with InsertionPoint(sequence.body): - test_function(sequence.bodyTarget) - transform.YieldOp() - print(module) - return test_function - - -def run_transform_structured_ext(test_function): - # Used by dialects/transform_structured_ext test module - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - test_function() - module.operation.verify() - print(module) - return test_function - - -def run_construct_and_print_in_module(test_function): - # Used by test modules: - # - integration/dialects/pdl - # - integration/dialects/transform - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - module = test_function(module) - if module is not None: - print(module) - return test_function - - -TEST_MODULES = [ - ("execution_engine", run), - ("pass_manager", run), - ("dialects/affine", run_with_insertion_point), - ("dialects/func", run_with_insertion_point), - ("dialects/arith_dialect", run), - ("dialects/arith_llvm", run), - ("dialects/async_dialect", run), - ("dialects/builtin", run), - ("dialects/cf", run_with_insertion_point_all_unreg_dialects), - ("dialects/complex_dialect", run), - ("dialects/func", run_with_insertion_point), - ("dialects/index_dialect", run_with_insertion_point_and_context_arg), - ("dialects/llvm", run_with_insertion_point), - ("dialects/math_dialect", run), - ("dialects/memref", run), - ("dialects/ml_program", run_with_insertion_point), - ("dialects/nvgpu", run_with_insertion_point), - ("dialects/nvvm", run_with_insertion_point), - ("dialects/ods_helpers", run), - ("dialects/openmp_ops", run_with_insertion_point), - ("dialects/pdl_ops", run_with_insertion_point), - # ("dialects/python_test", run), # TODO: Need to pass pybind11 or nanobind argv - ("dialects/quant", run), - ("dialects/rocdl", run_with_insertion_point), - ("dialects/scf", run_with_insertion_point), - ("dialects/shape", run), - ("dialects/spirv_dialect", run), - ("dialects/tensor", run), - # ("dialects/tosa", ), # Nothing to test - ("dialects/transform_bufferization_ext", run_with_insertion_point), - # ("dialects/transform_extras", ), # Needs a more complicated execution schema - ("dialects/transform_gpu_ext", run_transform_tensor_ext), - ( - "dialects/transform_interpreter", - run_with_context_and_location, - ["print_", "transform_options", "failed", "include"], - ), - ( - "dialects/transform_loop_ext", - run_with_insertion_point, - ["loopOutline"], - ), - ("dialects/transform_memref_ext", run_with_insertion_point), - ("dialects/transform_nvgpu_ext", run_with_insertion_point), - ("dialects/transform_sparse_tensor_ext", run_transform_tensor_ext), - ("dialects/transform_structured_ext", run_transform_structured_ext), - ("dialects/transform_tensor_ext", run_transform_tensor_ext), - ( - "dialects/transform_vector_ext", - run_apply_patterns, - ["configurable_patterns"], - ), - ("dialects/transform", run_with_insertion_point_and_module_arg), - ("dialects/vector", run_with_context_and_location), - ("dialects/gpu/dialect", run_with_context_and_location), - ("dialects/gpu/module-to-binary-nvvm", run_with_context_and_location), - ("dialects/gpu/module-to-binary-rocdl", run_with_context_and_location), - ("dialects/linalg/ops", run), - # TO ADD: No proper tests in this dialects/linalg/opsdsl/* - # ("dialects/linalg/opsdsl/*", ...), - ("dialects/sparse_tensor/dialect", run), - ("dialects/sparse_tensor/passes", run), - ("integration/dialects/pdl", run_construct_and_print_in_module), - ("integration/dialects/transform", run_construct_and_print_in_module), - ("integration/dialects/linalg/opsrun", run), - ("ir/affine_expr", run), - ("ir/affine_map", run), - ("ir/array_attributes", run), - ("ir/attributes", run), - ("ir/blocks", run), - ("ir/builtin_types", run), - ("ir/context_managers", run), - ("ir/debug", run), - ("ir/diagnostic_handler", run), - ("ir/dialects", run), - ("ir/exception", run), - ("ir/insertion_point", run), - ("ir/integer_set", run), - ("ir/location", run), - ("ir/module", run), - ("ir/operation", run), - ("ir/symbol_table", run), - ("ir/value", run), -] - -TESTS_TO_SKIP = [ - "test_execution_engine__testNanoTime_multi_threaded", # testNanoTime can't run in multiple threads, even with GIL - "test_execution_engine__testSharedLibLoad_multi_threaded", # testSharedLibLoad can't run in multiple threads, even with GIL - "test_dialects_arith_dialect__testArithValue_multi_threaded", # RuntimeError: Value caster is already registered: .ArithValue'>, even with GIL - "test_ir_dialects__testAppendPrefixSearchPath_multi_threaded", # PyGlobals::setDialectSearchPrefixes is not thread-safe, even with GIL. Strange usage of static PyGlobals vs python exposed _cext.globals - "test_ir_value__testValueCasters_multi_threaded", # RuntimeError: Value caster is already registered: .dont_cast_int, even with GIL - # tests indirectly calling thread-unsafe llvm::raw_ostream - "test_execution_engine__testInvalidModule_multi_threaded", # mlirExecutionEngineCreate calls thread-unsafe llvm::raw_ostream - "test_pass_manager__testPrintIrAfterAll_multi_threaded", # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream - "test_pass_manager__testPrintIrBeforeAndAfterAll_multi_threaded", # IRPrinterInstrumentation::runBeforePass calls thread-unsafe llvm::raw_ostream - "test_pass_manager__testPrintIrLargeLimitElements_multi_threaded", # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream - "test_pass_manager__testPrintIrTree_multi_threaded", # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream - "test_pass_manager__testRunPipeline_multi_threaded", # PrintOpStatsPass::printSummary calls thread-unsafe llvm::raw_ostream - "test_dialects_transform_interpreter__include_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream - "test_dialects_transform_interpreter__transform_options_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream - "test_dialects_transform_interpreter__print_self_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) call thread-unsafe llvm::raw_ostream - "test_ir_diagnostic_handler__testDiagnosticCallbackException_multi_threaded", # mlirEmitError calls thread-unsafe llvm::raw_ostream - "test_ir_module__testParseSuccess_multi_threaded", # mlirOperationDump calls thread-unsafe llvm::raw_ostream - # False-positive TSAN detected race in llvm::RuntimeDyldELF::registerEHFrames() - # Details: https://github.com/llvm/llvm-project/pull/107103/files#r1905726947 - "test_execution_engine__testCapsule_multi_threaded", - "test_execution_engine__testDumpToObjectFile_multi_threaded", -] - -TESTS_TO_XFAIL = [ - # execution_engine tests: - # - ctypes related data-races: https://github.com/python/cpython/issues/127945 - "test_execution_engine__testBF16Memref_multi_threaded", - "test_execution_engine__testBasicCallback_multi_threaded", - "test_execution_engine__testComplexMemrefAdd_multi_threaded", - "test_execution_engine__testComplexUnrankedMemrefAdd_multi_threaded", - "test_execution_engine__testDynamicMemrefAdd2D_multi_threaded", - "test_execution_engine__testF16MemrefAdd_multi_threaded", - "test_execution_engine__testF8E5M2Memref_multi_threaded", - "test_execution_engine__testInvokeFloatAdd_multi_threaded", - "test_execution_engine__testInvokeVoid_multi_threaded", # a ctypes race - "test_execution_engine__testMemrefAdd_multi_threaded", - "test_execution_engine__testRankedMemRefCallback_multi_threaded", - "test_execution_engine__testRankedMemRefWithOffsetCallback_multi_threaded", - "test_execution_engine__testUnrankedMemRefCallback_multi_threaded", - "test_execution_engine__testUnrankedMemRefWithOffsetCallback_multi_threaded", - # dialects tests - "test_dialects_memref__testSubViewOpInferReturnTypeExtensiveSlicing_multi_threaded", # Related to ctypes data races - "test_dialects_transform_interpreter__print_other_multi_threaded", # Fatal Python error: Aborted or mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe - "test_dialects_gpu_module-to-binary-rocdl__testGPUToASMBin_multi_threaded", # Due to global llvm-project/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp::GCNTrackers variable mutation - "test_dialects_gpu_module-to-binary-nvvm__testGPUToASMBin_multi_threaded", - "test_dialects_gpu_module-to-binary-nvvm__testGPUToLLVMBin_multi_threaded", - "test_dialects_gpu_module-to-binary-rocdl__testGPUToLLVMBin_multi_threaded", - # integration tests - "test_integration_dialects_linalg_opsrun__test_elemwise_builtin_multi_threaded", # Related to ctypes data races - "test_integration_dialects_linalg_opsrun__test_elemwise_generic_multi_threaded", # Related to ctypes data races - "test_integration_dialects_linalg_opsrun__test_fill_builtin_multi_threaded", # ctypes - "test_integration_dialects_linalg_opsrun__test_fill_generic_multi_threaded", # ctypes - "test_integration_dialects_linalg_opsrun__test_fill_rng_builtin_multi_threaded", # ctypes - "test_integration_dialects_linalg_opsrun__test_fill_rng_generic_multi_threaded", # ctypes - "test_integration_dialects_linalg_opsrun__test_max_pooling_builtin_multi_threaded", # ctypes - "test_integration_dialects_linalg_opsrun__test_max_pooling_generic_multi_threaded", # ctypes - "test_integration_dialects_linalg_opsrun__test_min_pooling_builtin_multi_threaded", # ctypes - "test_integration_dialects_linalg_opsrun__test_min_pooling_generic_multi_threaded", # ctypes -] - - -def add_existing_tests(test_modules, test_prefix: str = "_original_test"): - def decorator(test_cls): - this_folder = Path(__file__).parent.absolute() - test_cls.output_folder = tempfile.TemporaryDirectory() - output_folder = Path(test_cls.output_folder.name) - - for test_mod_info in test_modules: - # test_mod_info is a tuple of size 2 or 3: - # (test_module_str, run_test_function) or (test_module_str, run_test_function, test_name_patterns_list) - # For example: - # - ("ir/value", run) or - # - ("dialects/transform_loop_ext", run_with_insertion_point, ["loopOutline"]) - assert isinstance(test_mod_info, tuple) and len(test_mod_info) in (2, 3) - if len(test_mod_info) == 2: - test_module_name, exec_fn = test_mod_info - test_pattern = None - else: - test_module_name, exec_fn, test_pattern = test_mod_info - - src_filepath = this_folder / f"{test_module_name}.py" - dst_filepath = (output_folder / f"{test_module_name}.py").absolute() - if not dst_filepath.parent.exists(): - dst_filepath.parent.mkdir(parents=True) - copy_and_update(src_filepath, dst_filepath) - test_mod = import_from_path(test_module_name, dst_filepath) - for attr_name in dir(test_mod): - is_test_fn = test_pattern is None and attr_name.startswith("test") - is_test_fn |= test_pattern is not None and any( - [p in attr_name for p in test_pattern] - ) - if is_test_fn: - obj = getattr(test_mod, attr_name) - if callable(obj): - test_name = f"{test_prefix}_{test_module_name.replace('/', '_')}__{attr_name}" - - def wrapped_test_fn( - self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs - ): - __exec_fn__(__test_fn__) - - setattr(test_cls, test_name, wrapped_test_fn) - return test_cls - - return decorator - - -@contextmanager -def _capture_output(fp): - # Inspired from jax test_utils.py capture_stderr method - # ``None`` means nothing has not been captured yet. - captured = None - - def get_output() -> str: - if captured is None: - raise ValueError("get_output() called while the context is active.") - return captured - - with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as f: - original_fd = os.dup(fp.fileno()) - os.dup2(f.fileno(), fp.fileno()) - try: - yield get_output - finally: - # Python also has its own buffers, make sure everything is flushed. - fp.flush() - os.fsync(fp.fileno()) - f.seek(0) - captured = f.read() - os.dup2(original_fd, fp.fileno()) - - -capture_stdout = partial(_capture_output, sys.stdout) -capture_stderr = partial(_capture_output, sys.stderr) - - -def multi_threaded( - num_workers: int, - num_runs: int = 5, - skip_tests: Optional[list[str]] = None, - xfail_tests: Optional[list[str]] = None, - test_prefix: str = "_original_test", - multithreaded_test_postfix: str = "_multi_threaded", -): - """Decorator that runs a test in a multi-threaded environment.""" - - def decorator(test_cls): - for name, test_fn in test_cls.__dict__.copy().items(): - if not (name.startswith(test_prefix) and callable(test_fn)): - continue - - name = f"test{name[len(test_prefix):]}" - if skip_tests is not None: - if any( - test_name.replace(multithreaded_test_postfix, "") in name - for test_name in skip_tests - ): - continue - - def multi_threaded_test_fn(self, *args, __test_fn__=test_fn, **kwargs): - with capture_stdout(), capture_stderr() as get_output: - barrier = threading.Barrier(num_workers) - - def closure(): - barrier.wait() - for _ in range(num_runs): - __test_fn__(self, *args, **kwargs) - - with concurrent.futures.ThreadPoolExecutor( - max_workers=num_workers - ) as executor: - futures = [] - for _ in range(num_workers): - futures.append(executor.submit(closure)) - # We should call future.result() to re-raise an exception if test has - # failed - assert len(list(f.result() for f in futures)) == num_workers - - gc.collect() - assert Context._get_live_count() == 0 - - captured = get_output() - if len(captured) > 0 and "ThreadSanitizer" in captured: - raise RuntimeError( - f"ThreadSanitizer reported warnings:\n{captured}" - ) - - test_new_name = f"{name}{multithreaded_test_postfix}" - if xfail_tests is not None and test_new_name in xfail_tests: - multi_threaded_test_fn = unittest.expectedFailure( - multi_threaded_test_fn - ) - - setattr(test_cls, test_new_name, multi_threaded_test_fn) - - return test_cls - - return decorator - - -@multi_threaded( - num_workers=10, - num_runs=20, - skip_tests=TESTS_TO_SKIP, - xfail_tests=TESTS_TO_XFAIL, -) -@add_existing_tests(test_modules=TEST_MODULES, test_prefix="_original_test") -class TestAllMultiThreaded(unittest.TestCase): - @classmethod - def tearDownClass(cls): - if hasattr(cls, "output_folder"): - cls.output_folder.cleanup() - - def _original_test_create_context(self): - with Context() as ctx: - print(ctx._get_live_count()) - print(ctx._get_live_module_count()) - print(ctx._get_live_operation_count()) - print(ctx._get_live_operation_objects()) - print(ctx._get_context_again() is ctx) - print(ctx._clear_live_operations()) - - def _original_test_create_module_with_consts(self): - py_values = [123, 234, 345] - with Context() as ctx: - module = Module.create(loc=Location.file("foo.txt", 0, 0)) - - dtype = IntegerType.get_signless(64) - with InsertionPoint(module.body), Location.name("a"): - arith.constant(dtype, py_values[0]) - - with InsertionPoint(module.body), Location.name("b"): - arith.constant(dtype, py_values[1]) - - with InsertionPoint(module.body), Location.name("c"): - arith.constant(dtype, py_values[2]) - - -if __name__ == "__main__": - # Do not run the tests on CPython with GIL - if hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled(): - unittest.main()